Implement RemoteSequential slicing and extra repr, add tests (#30)

- finish renaming RemoteSequenceInfo -> RemoteSequenceManager (why: if it were an *Info, users would expect it to behave like a dataclass, whereas in actuality the class performs heavy network interactions on its own)
- implement RemoteSequenceManager.make_sequence (from https://pastebin.com/uXgy2U8B )
- make RemoteSequentialInferenceSession use RemoteSequenceManager.make_sequence
- make tests pass again
- make it possible to create an inference session without RemoteTransformerBlock
- make a standalone test for RemoteSequential (a usage sketch follows below)
- rollback convert-model
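
A rough usage sketch of the API this commit settles on, pieced together from the new standalone RemoteSequential test below; MODEL_NAME and INITIAL_PEERS come from the test environment and are not defined by this commit:

import os
import torch
from hivemind import DHT
from src import RemoteSequential
from src.client.remote_model import DistributedBloomConfig

MODEL_NAME = os.environ["MODEL_NAME"]                # a converted bloom model served by the test swarm
INITIAL_PEERS = os.environ["INITIAL_PEERS"].split()

config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)

sequential = RemoteSequential(config, dht)           # a remote pipeline over all config.n_layer blocks
first_half = sequential[: config.n_layer // 2]       # slicing returns another RemoteSequential over the same manager
print(repr(first_half))                              # extra_repr shows the uid range covered by the slice

inputs = torch.randn(1, 5, config.hidden_size)
outputs = sequential(inputs)                         # block-by-block forward with per-block retries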

Co-authored-by: Tim Dettmers <tim.dettmers@gmail.com>

justheuristic authored commit f0c7383181 (parent 6ee942e915, branch pull/32/head)

@@ -66,6 +66,8 @@ jobs:
run: |
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
export MODEL_NAME=bloom-testing/test-bloomd-350m-$HF_TAG
export REF_NAME=bigscience/bloom-350m
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
SERVER1_PID=$!
@@ -79,14 +81,7 @@ jobs:
sleep 30 # wait for server to download layers
# test individual blocks
export PYTHONPATH=.
BLOCK_UID=$MODEL_NAME.0 REF_NAME=$MODEL_NAME REF_INDEX=0 pytest tests/test_block_exact_match.py
BLOCK_UID=$MODEL_NAME.19 REF_NAME=$MODEL_NAME REF_INDEX=19 pytest tests/test_block_exact_match.py
REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
PYTHONPATH=. pytest tests
kill -s SIGINT $SERVER1_PID $SERVER2_PID
echo "Done!"

@@ -1,4 +1,5 @@
from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequential import RemoteSequential
from src.client.sequence_manager import RemoteSequenceManager

@@ -0,0 +1,173 @@
from __future__ import annotations
import asyncio
import contextlib
from typing import AsyncIterator, List, Optional
import torch
from hivemind import (
P2P,
anext,
deserialize_torch_tensor,
get_logger,
nested_flatten,
serialize_torch_tensor,
use_hivemind_log_handler,
)
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class RemoteTransformerBlockInferenceSession:
"""
An interface to a single multi-step *inference* session for a specific remote module on a specific server
:note: this inference session is *not* fault-tolerant out of the box
"""
def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
self.uid, self.rpc_info = uid, rpc_info
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.stepped = False
self.closed = False
@classmethod
async def _create(
cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
) -> RemoteTransformerBlockInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
inputs_queue = asyncio.Queue()
outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
return cls(uid, rpc_info, inputs_queue, outputs_stream)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
while True:
next_input_message = await asyncio.wait_for(queue.get(), timeout)
yield next_input_message
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"
def step(self, new_hidden_states: torch.Tensor):
"""Inference step: send a chunk of input tensors and receive a chunk of outputs"""
if self.closed:
raise Exception("Session is closed, cannot perform step")
# serialize inputs and put them into the queue
inputs = (new_hidden_states,)
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor, proto.compression)
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
],
)
)
)
outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
return outputs[0]
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await anext(self._outputs_stream)
def close(self):
"""Finish a given inference session, close the underlying connection"""
if self._outputs_stream is None:
return # already closed
RemoteExpertWorker.run_coroutine(self._aclose_stream())
self._outputs_stream = self._inputs_queue = None
self.closed = True
async def _aclose_stream(self):
"""Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
if self._outputs_stream is None:
return # already closed
if self.stepped:
await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
try:
await anext(self._outputs_stream)
except StopAsyncIteration:
pass
def __del__(self):
self.close()
def __enter__(self):
assert not self.closed
return self
def __exit__(self, *exc_details):
self.close()
class RemoteSequentialInferenceSession:
"""
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
"""
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
self.sequence_manager = sequence_manager
self.p2p = p2p
self.closed = False
self.chosen_spans: List[RemoteSpanInfo] = []
self.stack = contextlib.ExitStack()
self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
self.timeout = timeout
def __enter__(self):
assert not self.closed and not self.chosen_spans
self.stack.__enter__()
# TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
self.chosen_spans.extend(self.sequence_manager.make_sequence())
for chosen_span in self.chosen_spans:
stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
inference_session = RemoteExpertWorker.run_coroutine(
RemoteTransformerBlockInferenceSession._create(
stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
)
)
self.inference_sessions.append(inference_session)
self.stack.enter_context(inference_session)
return self
def step(self, inputs: torch.Tensor):
assert not self.closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
for session in self.inference_sessions:
outputs = session.step(inputs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
return inputs
def close(self, *exc_details):
"""Finish a given inference session, close the underlying connection"""
if not self.closed:
self.stack.__exit__(*exc_details or (None, None, None))
self.inference_sessions.clear()
self.closed = True
def __exit__(self, *exc_details):
self.close(*exc_details)
def __del__(self):
self.close()

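A minimal sketch of driving the new sessions step by step, following the recurrent loop in the full-model test further below; `sequential` is assumed to be a RemoteSequential built as in the sketch above:

inputs = torch.randn(1, 8, sequential.config.hidden_size)
outputs = []
# inference_session() builds server spans via RemoteSequenceManager.make_sequence()
# and opens one RemoteTransformerBlockInferenceSession per chosen span
with sequential.inference_session() as session:
    for i in range(inputs.shape[1]):
        # each step() pushes one chunk of hidden states through every span in order
        outputs.append(session.step(inputs[:, i : i + 1, :]))
outputs = torch.cat(outputs, dim=1)
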
@@ -1,20 +1,16 @@
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
from __future__ import annotations
import asyncio
import random
from typing import Any, AsyncIterator, Dict, Optional
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P, StubBase
from hivemind.proto import runtime_pb2
from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
from hivemind.utils import get_logger, use_hivemind_log_handler
from src.client.inference_session import RemoteTransformerBlockInferenceSession
from src.data_structures import RemoteModuleInfo
from src.dht_utils import ModuleUID
from src.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
@@ -39,100 +35,10 @@ class RemoteTransformerBlock(RemoteExpert):
def inference_session(self) -> RemoteTransformerBlockInferenceSession:
"""Initialize a new inference session with the specified remote server"""
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
return RemoteExpertWorker.run_coroutine(
RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
)
def begin_inference_session(self):
logger.warning("beging_inference_session was renamed to just inference_session")
return self.inference_session()
class RemoteTransformerBlockInferenceSession:
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
self.uid, self.info = uid, info
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.stepped = False
self.closed = False
@classmethod
async def _create(
cls,
remote_module: RemoteTransformerBlock,
timeout: Optional[float] = None,
) -> RemoteTransformerBlockInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
inputs_queue = asyncio.Queue()
outputs_stream = await remote_module.stub.rpc_inference(
cls._read_inputs_from_queue(inputs_queue, timeout),
timeout=timeout,
)
return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
while True:
next_input_message = await asyncio.wait_for(queue.get(), timeout)
yield next_input_message
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"
def step(self, new_hidden_states: torch.Tensor):
"""Inference step: send a chunk of input tensors and receive a chunk of outputs"""
if self.closed:
raise Exception("Session is closed, cannot perform step")
# serialize inputs and put them into the queue
inputs = (new_hidden_states,)
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor, proto.compression)
for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
],
)
)
)
outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
return outputs[0]
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await anext(self._outputs_stream)
def close(self):
"""Finish a given inference session, close the underlying connection"""
if self._outputs_stream is None:
return # already closed
RemoteExpertWorker.run_coroutine(self._aclose_stream())
self._outputs_stream = self._inputs_queue = None
self.closed = True
async def _aclose_stream(self):
"""Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
if self._outputs_stream is None:
return # already closed
if self.stepped:
await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
try:
await anext(self._outputs_stream)
except StopAsyncIteration:
pass
def __del__(self):
self.close()
def __enter__(self):
assert not self.closed
return self
def __exit__(self, *exc_details):
self.close()

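A sketch of single-block inference with the updated property above, mirroring the per-block exact-match test; `dht`, `config` and MODEL_NAME are assumed to be set up as in the earlier sketch:

from src.dht_utils import get_remote_module

remote_block = get_remote_module(dht, f"{MODEL_NAME}.0")    # a single RemoteTransformerBlock
x = torch.randn(1, 8, config.hidden_size)
# the session is now created from (stub, uid, info) internally, without touching RemoteTransformerBlock itself
with remote_block.inference_session() as sess:
    step_outputs = [sess.step(x[:, i : i + 1, :]) for i in range(x.shape[1])]
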
@@ -1,17 +1,15 @@
from __future__ import annotations
import contextlib
import logging
import random
from typing import Optional, Union
import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from torch import nn
import src
from src.client.inference_session import RemoteSequentialInferenceSession
from src.client.remote_block import RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import UID_DELIMITER
@@ -30,49 +28,41 @@ class RemoteSequential(nn.Module):
self,
config: src.DistributedBloomConfig,
dht: DHT,
prefix: str,
max_retries: int = 3,
dht_prefix: Optional[str] = None,
p2p: Optional[P2P] = None,
sequence_manager: Optional[RemoteSequenceManager] = None,
):
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
if prefix.endswith(UID_DELIMITER):
logger.warning(
f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
f"This will cause {self.__class__.__name__} to look for modules under "
f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
)
super().__init__()
self.config = config
self.dht = dht
self.prefix = prefix
self.max_retries = max_retries
self.dht_prefix = dht_prefix or config.dht_prefix
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
if sequence_manager is None:
logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
self.sequence_manager = RemoteSequenceManager(dht, block_uids)
self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
self.is_subsequence = False
else:
logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
self.sequence_manager = sequence_manager
assert isinstance(sequence_manager.block_uids, list)
logger.debug(f"Reusing sequence manager with {len(self.sequence_manager)}")
self.is_subsequence = self.sequence_manager.block_uids == block_uids
def forward(self, inputs: torch.Tensor):
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
for block_index in range(self.config.n_layer):
for retry_index in range(self.max_retries):
for block in iter(self):
for retry_index in range(self.sequence_manager.max_retries):
try:
block = self[block_index]
(outputs,) = block(inputs)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
inputs = outputs
break
except Exception as e:
if retry_index == self.max_retries - 1:
if retry_index == self.sequence_manager.max_retries - 1:
raise e
else:
logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
@@ -81,21 +71,20 @@ class RemoteSequential(nn.Module):
def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
assert isinstance(ix, (int, slice))
if isinstance(ix, int):
assert 0 <= ix < self.config.n_layer
assert 0 <= ix < len(self)
(module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
return module
else:
return RemoteSequential(
self.config,
self.dht,
prefix=self.prefix,
max_retries=self.max_retries,
dht_prefix=self.dht_prefix,
p2p=self.p2p,
sequence_manager=self.sequence_manager[ix],
)
def __iter__(self):
for block_index in range(self.config.n_layer):
for block_index in range(len(self)):
yield self[block_index]
def __len__(self):
@@ -105,56 +94,5 @@ class RemoteSequential(nn.Module):
self.sequence_manager.update_()
return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
class RemoteSequentialInferenceSession:
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
self.remote_sequence_info = remote_sequence_info
self.p2p = p2p
self.closed = False
self.stack = contextlib.ExitStack()
self.active_sessions = []
def __enter__(self):
assert not self.closed
self.stack.__enter__()
# TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
current_block = 0
while current_block != len(self.remote_sequence_info):
candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
chosen_span = random.choice(candidate_spans) # TODO this is a temporary code
assert chosen_span.start <= current_block < chosen_span.end
# TODO begin throwaway prototype code
remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
_ = remote.info # TODO fix
span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
self.active_sessions.append(remote.inference_session())
self.stack.enter_context(self.active_sessions[-1])
current_block = chosen_span.end
# TODO end throwaway prototype code
return self
def step(self, inputs: torch.Tensor):
assert not self.closed
for session in self.active_sessions:
outputs = session.step(inputs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
return inputs
def close(self, *exc_details):
"""Finish a given inference session, close the underlying connection"""
if not self.closed:
self.stack.__exit__(*exc_details or (None, None, None))
self.active_sessions.clear()
self.closed = True
def __exit__(self, *exc_details):
self.close(*exc_details)
def __del__(self):
self.close()
def extra_repr(self) -> str:
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

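For illustration, the extra_repr added above makes a RemoteSequential print the uid range it covers; a hypothetical example (the prefix and layer count below are made up, the actual values depend on config.dht_prefix and config.n_layer):

# hypothetical output for a model with dht_prefix "bigscience/test-bloomd-6b3" and 24 remote blocks
print(sequential)
# RemoteSequential(modules=bigscience/test-bloomd-6b3.0..bigscience/test-bloomd-6b3.23)
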
@@ -1,36 +1,37 @@
from __future__ import annotations
import random
import threading
from typing import List, Optional, Sequence, Tuple, Union
from hivemind import DHT, DHTExpiration
from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from src.dht_utils import get_remote_module_infos
from src.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class RemoteSequenceManager:
"""Keeps and updates the meta-information about which peers host which blocks"""
dht: DHT
block_uids: List[ModuleUID]
block_infos: List[Optional[RemoteModuleInfo]]
spans_by_priority: List[RemoteSpanInfo] # sorted from best to worst
spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
last_update_time: DHTExpiration
lock_changes: threading.Lock
def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
self.dht = dht
self.block_uids = list(block_uids)
self.block_infos = [None] * len(self.block_uids)
self.spans_by_priority = []
self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
self.last_update_time = -float("inf")
"""
Keeps and updates the meta-information about which peers host which blocks.
In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
"""
def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
self.dht, self.p2p = dht, p2p
self.block_uids: List[ModuleUID] = list(block_uids)
self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
self.spans_by_priority: List[RemoteSpanInfo] = [] # sorted from best to worst
self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
self.last_update_time: DHTExpiration = -float("inf")
self.max_retries = max_retries
self._rpc_info = None
self.lock_changes = threading.Lock()
self.update_()
@@ -38,13 +39,33 @@ class RemoteSequenceManager:
assert info is not None, f"Found no remote peers for block {uid}"
assert self.spans_by_priority and self.spans_containing_block
def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
"""
Form a sequence of remote servers that collectively serve all consecutive layers
:param start_index: optional index of the first module in the sequence, defaults to the first of block_uids
:param end_index: optional index of the last module (non-inclusive), defaults to just after the last of block_uids
"""
end_index = end_index if end_index is not None else len(self.block_uids)
span_sequence = []
current_index = start_index
while current_index < end_index:
candidate_spans = self.spans_containing_block[current_index]
chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
assert chosen_span.start <= current_index < chosen_span.end
span_sequence.append(chosen_span)
current_index = chosen_span.end
return span_sequence
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
"""Get a RemoteSequenceManager for a sub-sequence of blocks"""
assert isinstance(ix, (int, slice))
if not isinstance(ix, slice):
ix = slice(int(ix), int(ix) + 1, 1)
with self.lock_changes:
subseq = RemoteSequenceManager(self.dht, self.block_uids[ix])
subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
subseq.block_infos = self.block_infos[ix]
subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
subseq.last_update_time = self.last_update_time
@@ -102,3 +123,25 @@ class RemoteSequenceManager:
def __len__(self):
return len(self.block_uids)
@property
def rpc_info(self):
"""Return the rpc_info queried from one of the servers that hold the first block"""
if self._rpc_info is None:
retries = 0
for i in range(self.max_retries):
try:
self.update_()
peer_id = random.choice(list(self.block_infos[0].servers.keys()))
stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
outputs = RemoteExpertWorker.run_coroutine(
stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
)
self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
break  # stop retrying once the info has been fetched successfully
except Exception as e:
retries += 1
if retries >= self.max_retries:
raise e
else:
logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
return self._rpc_info

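To show how these pieces fit together on the client side, here is a short sketch of turning make_sequence output into one chained RPC uid per server, the same pattern RemoteSequentialInferenceSession.__enter__ uses above; `manager` is assumed to be an already-initialized RemoteSequenceManager:

from src.data_structures import CHAIN_DELIMITER

spans = manager.make_sequence()      # greedy walk over spans_containing_block, random choice per gap for now
for span in spans:
    # blocks [span.start, span.end) are hosted by the same peer, so they form a single chained request
    chained_uid = CHAIN_DELIMITER.join(manager.block_uids[span.start : span.end])
    print(span.peer_id, chained_uid)
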
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict
from typing import Any, Dict
from hivemind import PeerID
@@ -36,3 +36,6 @@ class RemoteSpanInfo:
start: int
end: int
peer_id: PeerID
RPCInfo = Dict[str, Any]

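A tiny illustration of the span convention used throughout this change: RemoteSpanInfo indices are half-open, so a span covers blocks start..end-1 on a single peer (this assumes RemoteSpanInfo is a plain dataclass; the peer id is left as a placeholder):

from src.data_structures import RemoteSpanInfo

span = RemoteSpanInfo(start=0, end=4, peer_id=None)  # peer_id would normally be a hivemind PeerID
assert span.start <= 3 < span.end        # block 3 is served by this span...
assert not (span.start <= 4 < span.end)  # ...but block 4 belongs to the next span
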
@@ -1,4 +1,3 @@
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import contextlib
from typing import AsyncIterator, Dict, Sequence

@@ -0,0 +1,51 @@
import asyncio
import gc
from contextlib import suppress
import psutil
import pytest
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.mpfuture import MPFuture
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
@pytest.fixture
def event_loop():
"""
This overrides the ``event_loop`` fixture from pytest-asyncio
(e.g. to make it compatible with ``asyncio.subprocess``).
This fixture is identical to the original one but does not call ``loop.close()`` in the end.
Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops).
However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.
For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer
fails if the loop is closed, but works if the loop is only stopped).
"""
yield asyncio.get_event_loop()
@pytest.fixture(autouse=True, scope="session")
def cleanup_children():
yield
with RSAPrivateKey._process_wide_key_lock:
RSAPrivateKey._process_wide_key = None
gc.collect() # Call .__del__() for removed objects
children = psutil.Process().children(recursive=True)
if children:
logger.info(f"Cleaning up {len(children)} leftover child processes")
for child in children:
with suppress(psutil.NoSuchProcess):
child.terminate()
psutil.wait_procs(children, timeout=1)
for child in children:
with suppress(psutil.NoSuchProcess):
child.kill()
MPFuture.reset_backend()

@@ -1,47 +1,39 @@
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import os
import random
import hivemind
import pytest
import torch
import transformers
from test_utils import *
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.data_structures import UID_DELIMITER
from src.dht_utils import get_remote_module
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
BLOCK_UID = os.environ.get("BLOCK_UID")
if not BLOCK_UID:
raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
@pytest.mark.forked
def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
remote_block = get_remote_module(dht, BLOCK_UID)
assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
ref_config = transformers.AutoConfig.from_pretrained(REF_NAME)
for block_index in random.sample(range(config.n_layer), 3):
block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
remote_block = get_remote_module(dht, block_uid)
assert remote_block is not None, f"Could not find {block_uid} in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
inputs = torch.randn(1, 8, ref_config.hidden_size)
(outputs_forward,) = remote_block(inputs)
inputs = torch.randn(1, 8, config.hidden_size)
(outputs_forward,) = remote_block(inputs)
outputs_inference = []
with remote_block.inference_session() as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
outputs_inference = []
with remote_block.inference_session() as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
ref_block = load_pretrained_block(REF_NAME, REF_INDEX, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

@@ -3,30 +3,20 @@
# - if you want more stable tests, see test_block_exact_match
# - if you want to figure out chained inference, ask yozh
import os
import hivemind
import pytest
import torch
import transformers
from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
from test_utils import *
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
@@ -38,9 +28,9 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
ref_blocks = [
load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
]
inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
outputs_rpc = remote_block.forward(inputs)[0]
@@ -59,6 +49,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
@@ -78,8 +69,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
outputs_inference = torch.cat(outputs_inference, dim=1)
ref_blocks = [
load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
]
outputs_ref = []
caches = [None, None]

@@ -1,9 +1,8 @@
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import os
import pytest
import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *
from src.client.remote_model import DistributedBloomForCausalLM
@@ -11,19 +10,7 @@ use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME")
@pytest.mark.forked
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
@@ -31,23 +18,12 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
assert len(model.transformer.h) == model.config.n_layer
test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
parallel_outputs = model.forward(test_inputs).logits
assert torch.all(torch.isfinite(parallel_outputs))
logger.info("Forward outputs are finite")
if REF_NAME:
with torch.no_grad():
ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
del ref_model, ref_outputs
else:
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
with torch.no_grad():
parallel_outputs = model.forward(test_inputs).logits
assert torch.all(torch.isfinite(parallel_outputs))
logger.info("Forward outputs are finite")
with torch.inference_mode():
embs = model.transformer.word_embeddings(test_inputs)
embs = model.transformer.word_embeddings_layernorm(embs)
recurrent_outputs = []
@@ -60,5 +36,20 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
dictionary = model.transformer.word_embeddings.weight.t()
recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
recurrent_outputs = (recurrent_outputs @ dictionary).float()
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")
del model, recurrent_outputs
if REF_NAME:
ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
del ref_model, ref_outputs, dummy_mask
else:
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
assert False

@@ -0,0 +1,43 @@
import pytest
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from test_utils import *
from src import RemoteSequential
from src.client.remote_model import DistributedBloomConfig
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@pytest.mark.forked
def test_remote_sequential():
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
grad_proj = torch.randn(1, 5, config.hidden_size)
sequential = RemoteSequential(config, dht)
full_outputs = sequential(test_inputs)
(full_outputs * grad_proj).sum().backward()
assert test_inputs.grad is not None
full_grad = test_inputs.grad.clone()
test_inputs.grad.data.zero_()
first_half = sequential[: config.n_layer // 2]
second_half = sequential[config.n_layer // 2 :]
assert len(first_half) + len(second_half) == len(sequential)
assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
for m in sequential, first_half, second_half:
assert isinstance(repr(m), str)
hidden = first_half(test_inputs)
assert isinstance(hidden, torch.Tensor)
assert hidden.shape == test_inputs.shape
assert hidden.requires_grad
second_half_outputs = second_half(hidden)
assert torch.allclose(second_half_outputs, full_outputs)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad)

@@ -0,0 +1,13 @@
import os
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME")