RemoteTransformerBlock

pull/9/head
justheuristic 2 years ago
parent 1cca611c9f
commit 3e9fd63a02

@@ -0,0 +1 @@
+from hivemind.moe.client import RemoteExpert

@@ -0,0 +1,58 @@
+from concurrent.futures import Future
+from functools import partial
+from typing import List, Optional, Union, Sequence
+
+from hivemind.moe.client import RemoteExpert
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertUID
+from hivemind.moe.server.dht_handler import _get_experts
+from hivemind.p2p import StubBase, P2P
+from hivemind.proto.runtime_pb2 import ExpertInfo
+from hivemind.dht import DHTExpiration, DHT
+from hivemind.utils import MPFuture
+
+from src.server.handler import TransformerConnectionHandler
+
+
+class RemoteTransformerBlock(RemoteExpert):
+    @property
+    def stub(self) -> StubBase:
+        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
+
+
+def get_remote_module(
+    dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
+) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
+    """
+    :param uids: find experts with these ids from across the DHT
+    :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteTransformerBlock if found else None]
+    """
+    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    return create_remote_module(result, dht, return_future)
+
+
+def create_remote_module(
+    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
+
+
+def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
+    experts: List[Optional[RemoteTransformerBlock]] = []
+    for info in infos:
+        if info is not None:
+            experts.append(RemoteTransformerBlock(info, p2p))
+        else:
+            experts.append(None)
+    return experts
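
A minimal client-side usage sketch of the loader added above. Only get_remote_module and the "dummy_block.{i}" UID scheme come from this commit; the import path, the initial peer multiaddr, and the swarm setup are assumptions (the rendered diff does not show file names).

import hivemind

from src.client import get_remote_module  # module path assumed; not shown in this diff

# Hypothetical client: join an existing swarm and resolve two served blocks.
dht = hivemind.DHT(
    initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmServerPeerID"],  # placeholder multiaddr
    client_mode=True,
    start=True,
)

uids = ["dummy_block.0", "dummy_block.1"]  # UID naming used by the server code further down
blocks = get_remote_module(dht, uids)  # List[Optional[RemoteTransformerBlock]]
for uid, block in zip(uids, blocks):
    print(uid, "found" if block is not None else "not found on the DHT")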

@@ -19,32 +19,15 @@ from src.server.cache import MemoryCache
 class BloomBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
-    def __init__(self, name: str, module: BloomBlock, *, memory_cache: MemoryCache, **kwargs):
-        object().__init__() # to bypass super.__init__
-        self.name, self.module = name, module
+    def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
+        super().__init__(*args, **kwargs) # to bypass super.__init__
         self.memory_cache = memory_cache
-        for name, param in module.named_parameters():
+        for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
-        for name, buf in module.named_buffers():
+        for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
-        self.args_schema = (BatchTensorDescriptor(HARDCODCED_LENGTH, module.hidden_size),)
-        self.kwargs_schema = {}
-        self.outputs_schema = (BatchTensorDescriptor(HARDCODCED_LENGTH, module.hidden_size),)
-        self.forward_schema = (self.args_schema, self.kwargs_schema) # inputs for forward
-        self.backward_schema = (self.forward_schema, self.outputs_schema) # inputs to backward
-        self.grad_inputs_schema = self.forward_schema # outputs from backward have same shape as inputs for forward
-        self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
-        self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
 
     @property
     def expert(self):
         #TODO un-hardcode this naming from hivemind
         return self.module
 
     def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
         with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             raise NotImplementedError("TODO")
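
forward_incremental above is left as a stub. As a rough sketch of the direction it implies, a hypothetical body could look like the following; only the memory_cache.use_cache contract is taken from this hunk, while the self.module call signature (layer_past/use_cache) and the [batch, seq_len, ...] cache layout are assumptions, not code from this commit.

def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
    # Hypothetical sketch, not the author's implementation.
    (hidden_states,) = inputs  # only the newly generated tokens: [batch, n_new, hidden_size]
    with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
        # Reuse the attention state already stored for the first `current_length` tokens ...
        layer_past = (cached_keys[:, :current_length], cached_values[:, :current_length])
        hidden_states, (new_keys, new_values) = self.module(hidden_states, layer_past=layer_past, use_cache=True)
        # ... and write the freshly computed keys/values into the next cache slots.
        # (How current_length itself is advanced is left to MemoryCache and not shown here.)
        new_length = current_length + hidden_states.shape[1]
        cached_keys[:, current_length:new_length] = new_keys[:, current_length:new_length]
        cached_values[:, current_length:new_length] = new_values[:, current_length:new_length]
    return (hidden_states,)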

@@ -7,7 +7,7 @@ from hivemind.proto import runtime_pb2
 from src.bloom.block import BloomBlock
 
 
-class BloomConnectionHandler(ConnectionHandler):
+class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
     def __init__(self, *args, **kwargs):

@@ -15,7 +15,7 @@ from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 from src.server.backend import BloomBlockBackend
-from src.server.handler import BloomConnectionHandler
+from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -31,7 +31,7 @@ class Server(threading.Thread):
     ):
         threading.Thread.__init__(self)
         self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
-        self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
         self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
@@ -105,15 +105,17 @@ class Server(threading.Thread):
         blocks = {}
         for i in range(num_blocks):
             module_uid = f"dummy_block.{i}"
-            HARDCODCED_LENGTH = 2048
+            block = BloomBlock(block_config, layer_number=i)
+            for param in block.parameters():
+                param.requires_grad = False
 
             blocks[module_uid] = BloomBlockBackend(
                 module_uid,
-                BloomBlock(block_config, layer_number=i),
+                block,
                 memory_cache=memory_cache,
-                args_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
+                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                 kwargs_schema={},
-                outputs_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
+                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )
@@ -121,7 +123,6 @@ class Server(threading.Thread):
         return cls(
             dht,
             blocks,
-            cache_size_bytes=cache_size_bytes,
             num_connection_handlers=num_handlers,
             device=device,
             stats_report_interval=stats_report_interval,
