remove transformer block, implement it as a sequential of size 1 (#54)

* remove transformer block, implement it as a sequence of size 1
* reimplement get_remote_module
* fix readme

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Branch: pull/57/head
Author: Pavel Samygin (committed by GitHub)
Parent: 77220c718c
Commit: 0be21775af

@@ -37,18 +37,18 @@ Then open a python notebook or console and run:
```python
import torch
import hivemind
from src import get_remote_module
from src import DistributedBloomConfig, get_remote_module
dht = hivemind.DHT(
initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS], # e.g. /ip4/127.0.0.1/...
client_mode=True, start=True,
)
layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'])
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")
layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config)
assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
# test forward/backward, two blocks
outputs, = layer4(*layer3(torch.randn(1, 64, 4096)))
outputs = layer4(layer3(torch.randn(1, 64, 4096)))
loss = (outputs * torch.randn_like(outputs)).norm()
loss.backward()
@@ -74,18 +74,18 @@ python -m cli.convert_model --model bigscience/bloom-6b3 \
To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables:
```bash
# shell A: serve blocks 3 and 4
# shell A: serve model
python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
--block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
--torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
# shell B: connect to the swarm and test individual blocks for exact match
export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
BLOCK_UID=bigscience/test-bloomd-6b3.3 pytest tests/test_block_exact_match.py
BLOCK_UID=bigscience/test-bloomd-6b3.4 pytest tests/test_block_exact_match.py
# shell B:
export PYTHONPATH=.
export INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
export MODEL_NAME="bigscience/test-bloomd-6b3"
# the test below will fail because there is no server that serves layer 7
# BLOCK_UID=bigscience/test-bloomd-6b3.7 pytest tests/test_block_exact_match.py
# test individual random blocks for exact match
pytest tests/test_block_exact_match.py
# test the full model (requires that servers collectively serve all model layers)
REF_NAME=bigscience/bloom-6b3 pytest tests/test_full_model.py
# test the full model
pytest tests/test_full_model.py
```

@@ -1,5 +1,4 @@
from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequential import RemoteSequential
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager

@@ -1,40 +0,0 @@
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
from __future__ import annotations
import random
import torch
from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P, StubBase
from hivemind.utils import get_logger, use_hivemind_log_handler
from src.client.inference_session import RemoteTransformerBlockInferenceSession
from src.data_structures import RemoteModuleInfo
from src.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class RemoteTransformerBlock(RemoteExpert):
"""A class that interacts with a remote module on a specific server for forward/backward or inference"""
def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys()))) # TODO replace this
super().__init__(peer_info, p2p)
@property
def stub(self) -> StubBase:
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
def forward(self, inputs: torch.Tensor, **kwargs):
for k, v in kwargs.items():
assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
return super().forward(inputs)
def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
"""Initialize a new inference session with the specified remote server"""
return RemoteExpertWorker.run_coroutine(
RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
)

@@ -1,6 +1,5 @@
from __future__ import annotations
import logging
from typing import Optional, Union
import torch
@@ -10,11 +9,9 @@ from torch import nn
import src
from src.client.inference_session import RemoteSequentialInferenceSession
from src.client.remote_block import RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager
from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos
from src.utils.misc import DUMMY
use_hivemind_log_handler("in_root_logger")
@@ -57,12 +54,16 @@ class RemoteSequential(nn.Module):
outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
return outputs
def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
assert isinstance(ix, (int, slice))
if isinstance(ix, int):
assert 0 <= ix < len(self)
(module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
return module
return RemoteTransformerBlock(
self.config,
self.dht,
dht_prefix=self.dht_prefix,
p2p=self.p2p,
sequence_manager=self.sequence_manager[ix],
)
else:
return RemoteSequential(
self.config,
@@ -85,3 +86,18 @@ class RemoteSequential(nn.Module):
def extra_repr(self) -> str:
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
class RemoteTransformerBlock(RemoteSequential):
"""Single transformer block hosted by swarm
This class is deprecated and kept for backward compatibility.
It will be removed soon in favor of using ``RemoteSequential`` directly.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert len(self) == 1, "RemoteTransformerBlock is a sequence of size 1"
def extra_repr(self):
return f"{self.sequence_manager.block_uids[0]}"

@@ -82,6 +82,7 @@ class RemoteSequenceManager:
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
if info is None:
logger.warning(f"Found no block info for block {uid}")
continue
if not isinstance(info, RemoteModuleInfo):
logger.warning(f"Unexpected dht entry type for {uid}: {info}")
if not info.servers:
@@ -95,22 +96,24 @@ class RemoteSequenceManager:
closed_spans = []
active_spans = {}
for block_index, info in enumerate(block_infos):
for peer_id, server in info.servers.items():
if server.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
else: # peer_id in active_spans
active_spans[peer_id].end = block_index + 1
if info is not None:
for peer_id, server in info.servers.items():
if server.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
else: # peer_id in active_spans
active_spans[peer_id].end = block_index + 1
for peer_id in list(active_spans.keys()):
if (
peer_id not in info.servers
info is None
or peer_id not in info.servers
or info.servers[peer_id].state != ServerState.ONLINE
or block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans
assert not active_spans, f"spans: {active_spans}"
closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
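For reference, the span computation above makes a single pass over the block table: every peer that keeps serving consecutive online blocks grows an active span, a span is closed as soon as that peer stops serving the next block (or the block has no DHT info at all), and the closed spans are finally sorted longest-first so routing prefers peers that cover the most consecutive blocks. A simplified, self-contained sketch of the same grouping idea, using a toy `Span` dataclass and plain sets of peer ids instead of the real `RemoteSpanInfo`/`ServerState` types:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Set


@dataclass
class Span:
    start: int  # first block index covered by this span
    end: int    # exclusive end, as in RemoteSpanInfo
    peer_id: str


def compute_spans(block_servers: List[Optional[Set[str]]]) -> List[Span]:
    """block_servers[i] is the set of online peer ids serving block i (None = no DHT info)."""
    active: Dict[str, Span] = {}
    closed: List[Span] = []
    for block_index, servers in enumerate(block_servers):
        for peer_id in servers or ():
            if peer_id not in active:
                active[peer_id] = Span(block_index, block_index + 1, peer_id)
            else:
                active[peer_id].end = block_index + 1
        for peer_id in list(active):
            if servers is None or peer_id not in servers or block_index == len(block_servers) - 1:
                closed.append(active.pop(peer_id))
    # longest spans first: prefer peers that serve the most consecutive blocks
    closed.sort(key=lambda span: span.end - span.start, reverse=True)
    return closed


print(compute_spans([{"A"}, {"A", "B"}, None, {"B"}]))
# -> [Span(start=0, end=2, peer_id='A'), Span(start=1, end=2, peer_id='B'), Span(start=3, end=4, peer_id='B')]
```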

@@ -110,7 +110,7 @@ async def sequential_forward(
If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
"""
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
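The docstring above describes the fault-tolerance contract of `sequential_forward`: push the activations through one server span at a time and, whenever a span fails, rebuild the route for the blocks that remain and keep going. A toy sketch of that control flow; `plan_route` and `run_span` are hypothetical stand-ins, not helpers from this repository:

```python
import torch


def sequential_forward_sketch(inputs: torch.Tensor, num_blocks: int, plan_route, run_span) -> torch.Tensor:
    """Push `inputs` through blocks [0, num_blocks); re-plan the remaining path on failure."""
    block_index = 0
    spans = plan_route(block_index, num_blocks)  # spans of consecutive blocks, one server each
    while block_index < num_blocks:
        span = spans.pop(0)
        try:
            inputs = run_span(span, inputs)      # remote forward through a single server span
            block_index = span.end               # `end` is exclusive, as in RemoteSpanInfo
        except Exception:
            spans = plan_route(block_index, num_blocks)  # rebuild the rest of the route, retry
    return inputs
```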

@@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Sequence, Union
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P, PeerID
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
import src
@@ -72,34 +72,63 @@ async def _declare_active_modules(
)
def get_remote_sequence(
dht: DHT,
start: int,
stop: int,
config: src.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
return_future: bool = False,
) -> Union[src.RemoteSequential, MPFuture]:
return RemoteExpertWorker.run_coroutine(
_get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
)
async def _get_remote_sequence(
dht: DHT,
start: int,
stop: int,
config: src.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
) -> src.RemoteSequential:
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
p2p = await dht.replicate_p2p()
manager = src.RemoteSequenceManager(dht, uids, p2p)
return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
def get_remote_module(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
expiration_time: Optional[DHTExpiration] = None,
config: src.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
return_future: bool = False,
) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
"""
:param uid_or_uids: find one or more modules with these ids from across the DHT
:param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
:param config: model config, usually obtained via .from_pretrained(MODEL_NAME)
:param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
:returns: a list of [RemoteTransformerBlock if found else None]
:returns: a single RemoteTransformerBlock if a single uid was passed, otherwise a list of RemoteTransformerBlocks
"""
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
infos = dht.run_coroutine(
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
return RemoteExpertWorker.run_coroutine(
_get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
)
if return_future:
async def _unpack(infos_future: MPFuture, dht: DHT):
p2p = await dht.replicate_p2p()
modules = _create_remote_modules_from_infos(await infos_future, p2p)
return modules[0] if single_uid else modules
return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
modules = _create_remote_modules_from_infos(infos, p2p)
async def _get_remote_module(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
config: src.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
p2p = await dht.replicate_p2p()
managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
modules = [
src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
]
return modules[0] if single_uid else modules
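Taken together, the reworked helpers give two entry points: `get_remote_sequence` for a contiguous range of blocks and `get_remote_module` for one or several explicit uids, returning a single `RemoteTransformerBlock` or a list of them accordingly. A short usage sketch; the model name and initial peers are placeholders, as in the README:

```python
import hivemind

from src import DistributedBloomConfig
from src.data_structures import UID_DELIMITER
from src.dht_utils import get_remote_module, get_remote_sequence

MODEL_NAME = "bigscience/test-bloomd-6b3"  # placeholder test model
INITIAL_PEERS = ["/ip4/TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS"]  # placeholder

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)

# one uid -> one RemoteTransformerBlock (a RemoteSequential of size 1)
block3 = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}3", config)

# several uids -> a list of blocks, one sequence manager per uid
block3, block4 = get_remote_module(dht, [f"{MODEL_NAME}{UID_DELIMITER}3", f"{MODEL_NAME}{UID_DELIMITER}4"], config)

# a contiguous range -> a single RemoteSequential over blocks 3..5
blocks = get_remote_sequence(dht, 3, 6, config)
```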
@@ -149,15 +178,3 @@ async def _get_remote_module_infos(
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules
def _create_remote_modules_from_infos(
infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
) -> List[Optional[src.RemoteTransformerBlock]]:
modules: List[Optional[src.RemoteTransformerBlock]] = []
for info in infos:
if info is not None:
modules.append(src.RemoteTransformerBlock(info, p2p))
else:
modules.append(None)
return modules

@@ -7,8 +7,10 @@ import transformers
from hivemind import P2PHandlerError
from test_utils import *
import src
from src import DistributedBloomConfig
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_sequential import RemoteTransformerBlock
from src.data_structures import UID_DELIMITER
from src.dht_utils import get_remote_module
@@ -16,16 +18,14 @@ from src.dht_utils import get_remote_module
@pytest.mark.forked
def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
for block_index in random.sample(range(config.n_layer), 3):
block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
remote_block = get_remote_module(dht, block_uid)
assert remote_block is not None, f"Could not find {block_uid} in DHT"
remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
assert isinstance(remote_block, RemoteTransformerBlock)
inputs = torch.randn(1, 8, config.hidden_size)
(outputs_forward,) = remote_block(inputs)
outputs_forward = remote_block(inputs)
outputs_inference = []
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:

@@ -7,25 +7,20 @@
import hivemind
import pytest
import torch
import transformers
from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
from test_utils import *
import src
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module
from src.client.remote_sequential import RemoteSequential
from src.dht_utils import get_remote_sequence
@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
_ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
remote_blocks = get_remote_sequence(dht, 3, 6, config)
assert isinstance(remote_blocks, RemoteSequential)
ref_blocks = [
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
@@ -33,7 +28,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
]
inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
outputs_rpc = remote_block.forward(inputs)[0]
outputs_rpc = remote_blocks.forward(inputs)
outputs_rpc.sum().backward()
grads_rpc = inputs.grad
@@ -52,18 +47,14 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
_ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
remote_blocks = get_remote_sequence(dht, 3, 5, config)
assert isinstance(remote_blocks, RemoteSequential)
inputs = torch.randn(1, 8, config.hidden_size)
outputs_inference = []
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
