refactor, add swarm info

This commit is contained in:
Aleksandr Borzunov 2022-06-29 14:26:47 +03:00 committed by justheuristic
parent 331591c915
commit b78d713347
5 changed files with 171 additions and 25 deletions

View File

@@ -65,7 +65,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()

 # test inference, one block
-with layer3.begin_inference_session() as sess:
+with layer3.inference_session() as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```

View File

@@ -11,13 +11,17 @@ from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, nested_flatten
+from hivemind.utils import anext, nested_flatten, use_hivemind_log_handler, get_logger
 from src.data_structures import RemoteModuleInfo
 from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler

+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
 class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
@@ -34,11 +38,15 @@ class RemoteTransformerBlock(RemoteExpert):
             assert v is None, f"Extra keyword arguments are not yet supported (got {k} = {v})"
         return super().forward(inputs)

-    def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
+    def inference_session(self) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
         _ = self.info  # create _info manually since the built-in property will not work inside RemoteExpertWorker
         return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))

+    def begin_inference_session(self):
+        logger.warning("begin_inference_session was renamed to just inference_session")
+        return self.inference_session()
+
+
 class RemoteTransformerBlockInferenceSession:
     """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""

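For orientation, here is a minimal usage sketch of the renamed API. It assumes `remote_block` is a `RemoteTransformerBlock` already connected to a live server and a hidden size of 4096 (as in the README example); the helper name `run_inference_steps` is made up for illustration.

```python
import torch

def run_inference_steps(remote_block, num_steps: int = 10) -> torch.Tensor:
    """Sketch: run a few single-token inference steps through one remote block."""
    outputs = []
    with remote_block.inference_session() as sess:  # new name introduced by this commit
        for _ in range(num_steps):
            # one step = one chunk of hidden states; the server keeps the attention cache
            outputs.append(sess.step(torch.ones(1, 1, 4096)))
    return torch.cat(outputs, dim=1)

# The old spelling still works during the deprecation window, but logs a warning:
# with remote_block.begin_inference_session() as sess: ...
```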
View File

@ -1,14 +1,19 @@
from __future__ import annotations
import dataclasses
import logging
import threading
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Tuple, NamedTuple, List, Sequence
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from hivemind import DHT, get_logger, use_hivemind_log_handler, PeerID
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from torch import nn
from src import DistributedBloomConfig
from src.data_structures import UID_DELIMITER, RemoteModuleInfo
from src.data_structures import UID_DELIMITER, RemoteModuleInfo, ModuleUID
from src.dht_utils import _create_remote_modules_from_infos, _get_remote_module_infos
use_hivemind_log_handler("in_root_logger")
@@ -32,27 +37,13 @@ class RemoteSequential(nn.Sequential):
         super().__init__()
         self.config = config
         self.dht = dht
-        self.prefix = prefix
-        self.max_retries = max_retries
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+        self.prefix = prefix
         self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
+        logger.debug(f"Remote block uids: {self.block_uids}")
-        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(
-            dht.run_coroutine(
-                partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
-                return_future=False,
-            )
-        )
+        self.max_retries = max_retries
-        assert len(self.block_infos) == len(self.block_uids)
-        for uid, info in zip(self.block_uids, self.block_infos):
-            assert isinstance(info, (type(None), RemoteModuleInfo)), f"Unexpected dht entry for {uid}: {info}"
-            assert info is not None, f"Found no active peers for block {uid}"
-            assert isinstance(info.peer_ids, set), f"expected peer_ids to be a set, got {info.peer_ids}"
-            assert info.uid == uid, f"The DHT entry for {uid} actually points to {info.uid}"
-            assert len(info.peer_ids) > 0, f"Found no active peers for block {uid}"
+        self.remote_model_info = RemoteModelInfo(dht, self.block_uids)

     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -80,3 +71,150 @@ class RemoteSequential(nn.Sequential):
     def __iter__(self):
         for block_index in range(self.config.n_layer):
             yield self[block_index]
+
+    def inference_session(self) -> RemoteSequentialInferenceSession:
+        # NOTE: RemoteSequentialInferenceSession is still commented out below,
+        # so calling this will fail with a NameError until that class lands
+        self.remote_model_info.update_()
+        return RemoteExpertWorker.run_coroutine(RemoteSequentialInferenceSession._create(self))
+
+
+Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
+
+
+@dataclasses.dataclass(frozen=False, init=False)
+class RemoteModelInfo:
+    """Stores meta-information about which peers host which blocks - and prepares to form sessions"""
+
+    dht: DHT
+    block_uids: Tuple[ModuleUID, ...]
+    block_infos: List[Optional[RemoteModuleInfo]]
+    spans_by_priority: List[Span]  # sorted from best to worst
+    spans_containing_block: Tuple[List[Span], ...]
+    lock_changes: threading.Lock
+
+    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
+        self.dht = dht
+        self.block_uids = block_uids
+        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
+        self.spans_by_priority = []
+        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
+        self.lock_changes = threading.Lock()
+        self.update_()
+
+        for uid, info in zip(self.block_uids, self.block_infos):
+            assert info is not None, f"Found no remote peers for block {uid}"
+        assert self.spans_by_priority and self.spans_containing_block
+
+    def update_(self):
+        with self.lock_changes:
+            self.update_block_infos_()
+            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+
+    def update_block_infos_(self):
+        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
+            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
+            return_future=False,
+        )
+        assert len(new_block_infos) == len(self.block_uids)
+        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
+            if info is None:
+                logger.warning(f"Found no block info for block {uid}")
+                continue  # keep the previous entry: the checks below would crash on None
+            if not isinstance(info, RemoteModuleInfo):
+                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
+                continue
+            if not info.peer_ids:
+                logger.warning(f"Found no active peers for block {uid}")
+            if info.uid != uid:
+                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
+            if not isinstance(info.peer_ids, set):
+                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
+            self.block_infos[block_index] = info
+
+    @staticmethod
+    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
+        closed_spans = []
+        active_spans = {}
+        for block_index, info in enumerate(block_infos):
+            # open a new span or extend the active one for every peer serving this block
+            for peer_id in info.peer_ids:
+                if peer_id not in active_spans:
+                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
+                else:  # peer_id in active_spans
+                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
+            # close spans whose peer does not serve this block (or when the last block is reached)
+            for peer_id in list(active_spans.keys()):
+                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
+                    closed_spans.append(active_spans.pop(peer_id))
+        assert not active_spans
+
+        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+
+        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
+        for span in closed_spans:
+            for block_index in range(span.start, span.end):
+                spans_containing_block[block_index].append(span)
+
+        return closed_spans, spans_containing_block
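To make the span bookkeeping concrete, here is a small worked example of `compute_spans` (a sketch: `SimpleNamespace` stubs stand in for the `RemoteModuleInfo` entries normally fetched from the DHT, and plain strings stand in for `PeerID`s):

```python
from types import SimpleNamespace

# Four blocks; peer "A" serves blocks 0..2, peer "B" serves blocks 1..3.
infos = [
    SimpleNamespace(uid=f"bloom.{i}", peer_ids=peers)
    for i, peers in enumerate([{"A"}, {"A", "B"}, {"A", "B"}, {"B"}])
]

spans_by_priority, spans_containing_block = RemoteModelInfo.compute_spans(infos)
# spans_by_priority == [Span(start=0, end=3, peer_id='A'), Span(start=1, end=4, peer_id='B')]
#   (both spans cover 3 blocks here; strictly longer spans would sort first)
# spans_containing_block[0] -> only A's span; [1] and [2] -> both spans; [3] -> only B's span
```

Sorting the longest spans first lets a session cover the whole layer range with as few peer-to-peer hops as possible.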
+
+
+# class RemoteSequentialInferenceSession:
+#     """An interface to a multi-step *inference* session for a sequence of remote modules"""
+#
+#     def __init__(self, block):
+#         self.closed = False
+#
+#     @classmethod
+#     async def _create(cls, remote_sequential: RemoteSequential, **kwargs) -> RemoteSequentialInferenceSession:
+#         """Create a new session for a sequence of modules. This code is meant to be run inside RemoteExpertWorker"""
+#         # TODO: choose spans via remote_sequential.remote_model_info and open one session per span
+#         raise NotImplementedError()
+#
+#     def step(self, new_hidden_states: torch.Tensor):
+#         """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+#         if self.closed:
+#             raise Exception("Session is closed, cannot perform step")
+#         # serialize inputs and put them into the queue
+#         inputs = (new_hidden_states,)
+#         outputs_serialized = RemoteExpertWorker.run_coroutine(
+#             self._step(
+#                 runtime_pb2.ExpertRequest(
+#                     uid=self.uid,
+#                     tensors=[
+#                         serialize_torch_tensor(tensor, proto.compression)
+#                         for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
+#                     ],
+#                 )
+#             )
+#         )
+#         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+#         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+#         return outputs[0]
+#
+#     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+#         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+#         await self._inputs_queue.put(inputs_serialized)
+#         return await anext(self._outputs_stream)
+#
+#     def close(self):
+#         """Finish a given inference session, close the underlying connection"""
+#         if self._outputs_stream is None:
+#             return  # already closed
+#         RemoteExpertWorker.run_coroutine(self._aclose_stream())
+#         self._outputs_stream = self._inputs_queue = None
+#         self.closed = True
+#
+#     async def _aclose_stream(self):
+#         """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+#         if self._outputs_stream is None:
+#             return  # already closed
+#         await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+#         try:
+#             await anext(self._outputs_stream)
+#         except StopAsyncIteration:
+#             pass
+#
+#     def __del__(self):
+#         self.close()
+#
+#     def __enter__(self):
+#         assert not self.closed
+#         return self
+#
+#     def __exit__(self, *exc_details):
+#         self.close()

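The commented-out draft above still mirrors the single-block session. One plausible next step, sketched here purely as an assumption (not part of this commit), is to chain one already-open per-block session per layer; a span-aware version would then open one stream per `Span` instead of per block:

```python
import torch

class NaiveSequentialInferenceSession:
    """Hypothetical sketch: run each inference step through per-block sessions in layer order."""

    def __init__(self, block_sessions):
        self.block_sessions = block_sessions  # one open RemoteTransformerBlockInferenceSession per block
        self.closed = False

    def step(self, hidden_states: torch.Tensor) -> torch.Tensor:
        assert not self.closed, "session is closed"
        for sess in self.block_sessions:
            hidden_states = sess.step(hidden_states)  # each hop returns same-shape hidden states
        return hidden_states

    def close(self):
        if not self.closed:
            for sess in self.block_sessions:
                sess.close()
            self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, *exc_details):
        self.close()
```

A span-aware version would instead pick spans from `remote_model_info.spans_by_priority`, so that a single RPC stream serves each run of consecutive blocks hosted on the same peer.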
View File

@@ -32,7 +32,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     (outputs_forward,) = remote_block(inputs)

     outputs_inference = []
-    with remote_block.begin_inference_session() as sess:
+    with remote_block.inference_session() as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)

View File

@@ -39,7 +39,7 @@ def test_remote_block_exact_match(atol_inference=1e-4):
     inputs = torch.randn(1, 8, 4096)

     outputs_inference = []
-    with remote_block.begin_inference_session() as sess:
+    with remote_block.inference_session() as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)