remove transformer block, implement as sequential of size 1 (#54)

* remove transformer block, implement it as a sequence of size 1
* reimplement get_remote_module
* fix readme

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Branch: pull/57/head
Author: Pavel Samygin (committed by GitHub)
Parent: 77220c718c
Commit: 0be21775af

@@ -37,18 +37,18 @@ Then open a python notebook or console and run:
 ```python
 import torch
 import hivemind
-from src import get_remote_module
+from src import DistributedBloomConfig, get_remote_module
 
 dht = hivemind.DHT(
     initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
     client_mode=True, start=True,
 )
 
-layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'])
+config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")
+layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config)
 assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
 
 # test forward/backward, two blocks
-outputs, = layer4(*layer3(torch.randn(1, 64, 4096)))
+outputs = layer4(layer3(torch.randn(1, 64, 4096)))
 loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
@@ -74,18 +74,18 @@ python -m cli.convert_model --model bigscience/bloom-6b3 \
 
 To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables:
 ```bash
-# shell A: serve blocks 3 and 4
+# shell A: serve model
 python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
-  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
+  --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
 
-# shell B: connect to the swarm and test individual blocks for exact match
-export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
-BLOCK_UID=bigscience/test-bloomd-6b3.3 pytest tests/test_block_exact_match.py
-BLOCK_UID=bigscience/test-bloomd-6b3.4 pytest tests/test_block_exact_match.py
-# the test below will fail because there is no server that serves layer 7
-# BLOCK_UID=bigscience/test-bloomd-6b3.7 pytest tests/test_block_exact_match.py
-# test the full model (requires that servers collectively serve all model layers)
-REF_NAME=bigscience/bloom-6b3 pytest tests/test_full_model.py
+# shell B:
+export PYTHONPATH=.
+export INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
+export MODEL_NAME="bigscience/test-bloomd-6b3"
+
+# test individual random blocks for exact match
+pytest tests/test_block_exact_match.py
+
+# test the full model
+pytest tests/test_full_model.py
 ```

@@ -1,5 +1,4 @@
 from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
-from src.client.remote_block import RemoteTransformerBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequential import RemoteSequential
+from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager

@@ -1,40 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-from __future__ import annotations
-
-import random
-
-import torch
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.p2p import P2P, StubBase
-from hivemind.utils import get_logger, use_hivemind_log_handler
-
-from src.client.inference_session import RemoteTransformerBlockInferenceSession
-from src.data_structures import RemoteModuleInfo
-from src.server.handler import TransformerConnectionHandler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class RemoteTransformerBlock(RemoteExpert):
-    """A class that interacts with a remote module on a specific server for forward/backward or inference"""
-
-    def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
-        super().__init__(peer_info, p2p)
-
-    @property
-    def stub(self) -> StubBase:
-        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
-
-    def forward(self, inputs: torch.Tensor, **kwargs):
-        for k, v in kwargs.items():
-            assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
-        return super().forward(inputs)
-
-    def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
-        """Initialize a new inference session with the specified remote server"""
-        return RemoteExpertWorker.run_coroutine(
-            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
-        )

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import logging
 from typing import Optional, Union
 
 import torch
@@ -10,11 +9,9 @@ from torch import nn
 import src
 from src.client.inference_session import RemoteSequentialInferenceSession
-from src.client.remote_block import RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
-from src.dht_utils import _create_remote_modules_from_infos
 from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
@@ -57,12 +54,16 @@ class RemoteSequential(nn.Module):
         outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs
 
-    def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            assert 0 <= ix < len(self)
-            (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
-            return module
+            return RemoteTransformerBlock(
+                self.config,
+                self.dht,
+                dht_prefix=self.dht_prefix,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
         else:
             return RemoteSequential(
                 self.config,
@@ -85,3 +86,18 @@ class RemoteSequential(nn.Module):
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
+
+
+class RemoteTransformerBlock(RemoteSequential):
+    """Single transformer block hosted by swarm
+
+    This class is deprecated and kept for backward compatibility.
+    It will be removed soon in favor of using ``RemoteSequential`` directly.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert len(self) == 1, "Remote Block is a sequence size 1"
+
+    def extra_repr(self):
+        return f"{self.sequence_manager.block_uids[0]}"
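For context, here is a minimal usage sketch of the reworked indexing, assuming a running swarm, a `hivemind.DHT` handle `dht`, and a `DistributedBloomConfig` `config` for the served model; `get_remote_sequence` comes from the `src.dht_utils` changes further down in this diff, and the tensor shapes are illustrative:

```python
# Sketch only: `dht` and `config` are assumed to already exist; shapes are illustrative.
import torch

from src import RemoteTransformerBlock         # now re-exported from src.client.remote_sequential
from src.dht_utils import get_remote_sequence  # helper added in this commit

sequential = get_remote_sequence(dht, 0, config.n_layer, config)  # RemoteSequential over all blocks

block = sequential[3]                             # integer indexing now returns a RemoteTransformerBlock...
assert isinstance(block, RemoteTransformerBlock)  # ...which is itself a RemoteSequential of length 1
assert len(block) == 1

hidden_states = torch.randn(1, 8, config.hidden_size)
outputs = block(hidden_states)                    # plain tensor in, plain tensor out (no tuple unpacking)
```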

@@ -82,6 +82,7 @@ class RemoteSequenceManager:
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:
                 logger.warning(f"Found no block info for block {uid}")
+                continue
             if not isinstance(info, RemoteModuleInfo):
                 logger.warning(f"Unexpected dht entry type for {uid}: {info}")
             if not info.servers:
@@ -95,22 +96,24 @@ class RemoteSequenceManager:
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            for peer_id, server in info.servers.items():
-                if server.state != ServerState.ONLINE:
-                    continue
-                if peer_id not in active_spans:
-                    active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
-                else:  # peer_id in active_spans
-                    active_spans[peer_id].end = block_index + 1
+            if info is not None:
+                for peer_id, server in info.servers.items():
+                    if server.state != ServerState.ONLINE:
+                        continue
+                    if peer_id not in active_spans:
+                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    else:  # peer_id in active_spans
+                        active_spans[peer_id].end = block_index + 1
 
             for peer_id in list(active_spans.keys()):
                 if (
-                    peer_id not in info.servers
+                    info is None
+                    or peer_id not in info.servers
                     or info.servers[peer_id].state != ServerState.ONLINE
                     or block_index == len(block_infos) - 1
                 ):
                     closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans
+        assert not active_spans, f"spans: {active_spans}"
 
         closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
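To make the amended control flow easier to follow, here is a self-contained toy version of the span-building loop; `ToySpan`, `make_spans`, and the literal `"online"`/`"offline"` states are invented for this sketch, whereas the real code works with `RemoteModuleInfo`, `RemoteSpanInfo`, and `ServerState`:

```python
# Toy re-implementation of the span-building loop above, with invented types.
# A span is a maximal contiguous run of blocks served by the same online peer;
# blocks whose info is missing (None) or whose server went offline break the run.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToySpan:
    start: int  # first block index, inclusive
    end: int    # last block index + 1 (exclusive)
    peer_id: str


def make_spans(block_infos: List[Optional[Dict[str, str]]]) -> List[ToySpan]:
    """Each block_infos[i] maps peer_id -> "online"/"offline", or is None if the DHT had no entry."""
    closed_spans, active_spans = [], {}
    for block_index, info in enumerate(block_infos):
        if info is not None:
            for peer_id, state in info.items():
                if state != "online":
                    continue
                if peer_id not in active_spans:
                    active_spans[peer_id] = ToySpan(start=block_index, end=block_index + 1, peer_id=peer_id)
                else:  # extend this peer's running span
                    active_spans[peer_id].end = block_index + 1

        for peer_id in list(active_spans):
            if info is None or info.get(peer_id) != "online" or block_index == len(block_infos) - 1:
                closed_spans.append(active_spans.pop(peer_id))

    assert not active_spans
    closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
    return closed_spans


# Peer "A" serves blocks 0 and 2, but block 1 has no DHT entry, so "A" yields two length-1 spans.
print(make_spans([{"A": "online"}, None, {"A": "online", "B": "online"}]))
```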

@@ -110,7 +110,7 @@ async def sequential_forward(
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
     """
 
-    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)

@@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P, PeerID
+from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
@@ -72,34 +72,63 @@ async def _declare_active_modules(
     )
 
 
+def get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> Union[src.RemoteSequential, MPFuture]:
+    return RemoteExpertWorker.run_coroutine(
+        _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
+    )
+
+
+async def _get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> src.RemoteSequential:
+    uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
+    p2p = await dht.replicate_p2p()
+    manager = src.RemoteSequenceManager(dht, uids, p2p)
+    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+
+
 def get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    expiration_time: Optional[DHTExpiration] = None,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
+) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
-    :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
+    :param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock if found else None]
+    :returns: a list of [RemoteTransformerBlock]
     """
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
+    return RemoteExpertWorker.run_coroutine(
+        _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
     )
 
-    if return_future:
-
-        async def _unpack(infos_future: MPFuture, dht: DHT):
-            p2p = await dht.replicate_p2p()
-            modules = _create_remote_modules_from_infos(await infos_future, p2p)
-            return modules[0] if single_uid else modules
-
-        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
-    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-    modules = _create_remote_modules_from_infos(infos, p2p)
+
+async def _get_remote_module(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    p2p = await dht.replicate_p2p()
+    managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    modules = [
+        src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+    ]
     return modules[0] if single_uid else modules
@@ -149,15 +178,3 @@ async def _get_remote_module_infos(
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
     return modules
-
-
-def _create_remote_modules_from_infos(
-    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
-) -> List[Optional[src.RemoteTransformerBlock]]:
-    modules: List[Optional[src.RemoteTransformerBlock]] = []
-    for info in infos:
-        if info is not None:
-            modules.append(src.RemoteTransformerBlock(info, p2p))
-        else:
-            modules.append(None)
-    return modules
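For reference, a minimal sketch of how the reworked helpers are intended to be called, assuming a running swarm reachable via `INITIAL_PEERS` and the `bigscience/test-bloomd-6b3` checkpoint from the README; the variable names are illustrative:

```python
# Sketch only: INITIAL_PEERS must point at your swarm; model/UID names follow the README.
import hivemind

from src import DistributedBloomConfig
from src.dht_utils import get_remote_module, get_remote_sequence

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloomd-6b3")

# A single UID returns one RemoteTransformerBlock; a list of UIDs returns a list of blocks,
# each backed by its own single-block RemoteSequenceManager.
block = get_remote_module(dht, "bigscience/test-bloomd-6b3.3", config)

# A contiguous range of blocks [3, 6) returns a RemoteSequential backed by one RemoteSequenceManager.
blocks_3_to_5 = get_remote_sequence(dht, 3, 6, config)
```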

@@ -7,8 +7,10 @@ import transformers
 from hivemind import P2PHandlerError
 from test_utils import *
 
+import src
+from src import DistributedBloomConfig
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
+from src.client.remote_sequential import RemoteTransformerBlock
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import get_remote_module
@@ -16,16 +18,14 @@ from src.dht_utils import get_remote_module
 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
 
     for block_index in random.sample(range(config.n_layer), 3):
-        block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
-        remote_block = get_remote_module(dht, block_uid)
-        assert remote_block is not None, f"Could not find {block_uid} in DHT"
+        remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
         assert isinstance(remote_block, RemoteTransformerBlock)
 
         inputs = torch.randn(1, 8, config.hidden_size)
-        (outputs_forward,) = remote_block(inputs)
+        outputs_forward = remote_block(inputs)
 
         outputs_inference = []
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:

@@ -7,25 +7,20 @@
 import hivemind
 import pytest
 import torch
-import transformers
-from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
 from test_utils import *
 
+import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
+from src.client.remote_sequential import RemoteSequential
+from src.dht_utils import get_remote_sequence
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
-    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
-    remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    remote_blocks = get_remote_sequence(dht, 3, 6, config)
+    assert isinstance(remote_blocks, RemoteSequential)
 
     ref_blocks = [
         load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
@@ -33,7 +28,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
         load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
     ]
     inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
-    outputs_rpc = remote_block.forward(inputs)[0]
+    outputs_rpc = remote_blocks.forward(inputs)
     outputs_rpc.sum().backward()
     grads_rpc = inputs.grad
@@ -52,18 +47,14 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
-    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
-    remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    remote_blocks = get_remote_sequence(dht, 3, 5, config)
+    assert isinstance(remote_blocks, RemoteSequential)
 
     inputs = torch.randn(1, 8, config.hidden_size)
     outputs_inference = []
-    with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+    with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)
