diff --git a/README.md b/README.md
index c63f832..be95eb1 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ conda activate bloom-demo
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
 pip install bitsandbytes-cuda113==0.26.0
-pip install https://github.com/learning-at-home/hivemind/archive/bc2cccfdb0d7c905a12ef6c3ad052a1250af9878.zip
+pip install https://github.com/learning-at-home/hivemind/archive/master.zip
 pip install https://github.com/huggingface/transformers/archive/224bde91caff4ccfd12277ab5e9bf97c61e22ee9.zip
 ```
 
diff --git a/src/client/remote_block.py b/src/client/remote_block.py
index cd88b03..8fa3dc0 100644
--- a/src/client/remote_block.py
+++ b/src/client/remote_block.py
@@ -8,8 +8,8 @@ from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.dht_handler import _get_experts
 from hivemind.p2p import StubBase, P2P
 from hivemind.proto.runtime_pb2 import ExpertInfo
-from hivemind.dht import DHTExpiration, DHT
-from hivemind.utils import MPFuture
+from hivemind.dht import DHT
+from hivemind.utils import MPFuture, DHTExpiration
 
 from src.server.handler import TransformerConnectionHandler
 
diff --git a/src/server/backend.py b/src/server/backend.py
index 3b6a9fe..b765255 100644
--- a/src/server/backend.py
+++ b/src/server/backend.py
@@ -2,14 +2,13 @@
 from typing import Tuple, Sequence
 
 import torch
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 
-from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 
 
-class TransformerBlockBackend(ExpertBackend):
+class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
@@ -21,12 +20,12 @@ class TransformerBlockBackend(ExpertBackend):
 
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
-
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
     def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
-        with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
-            return inputs[0] * 2
+        with self.memory_cache.use_cache(attention_cache_handle) as cache:
+            cache[...] += 1
+            return inputs[0] + cache
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
diff --git a/src/server/handler.py b/src/server/handler.py
index b849e84..365631a 100644
--- a/src/server/handler.py
+++ b/src/server/handler.py
@@ -1,19 +1,21 @@
 from typing import AsyncIterator, Dict
 
 import torch
-from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
+from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor, ModuleBackend
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
 
-from src.server.backend import TransformerBlockBackend
+from src.server.backend import TransformerBackend
 
 
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
+        super().__init__(dht, module_backends)
+        for module_backend in module_backends.values():
+            assert isinstance(module_backend, TransformerBackend)
 
     async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -21,7 +23,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         request = await anext(requests)
         backend = self.experts[request.uid]
-        assert isinstance(backend, TransformerBlockBackend)
+        assert isinstance(backend, TransformerBackend)
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
         async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
diff --git a/src/server/server.py b/src/server/server.py
index c262ca1..405b614 100644
--- a/src/server/server.py
+++ b/src/server/server.py
@@ -14,7 +14,7 @@ import multiprocessing as mp
 from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.backend import TransformerBlockBackend
+from src.server.backend import TransformerBackend
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -27,7 +27,7 @@ class Server(threading.Thread):
     def __init__(
         self,
         dht: DHT,
-        module_backends: Dict[str, TransformerBlockBackend],
+        module_backends: Dict[str, TransformerBackend],
         *,
         device: torch.device,
         num_connection_handlers: int = 8,
@@ -118,7 +118,7 @@
         for param in block.parameters():
             param.requires_grad = False
 
-        blocks[module_uid] = TransformerBlockBackend(
+        blocks[module_uid] = TransformerBackend(
             module_uid,
             block,
             memory_cache=memory_cache,
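
The new `TransformerBackend.inference_step` in this diff is still a placeholder: it increments the attention cache in place and adds it to the first input. A minimal, self-contained sketch of those semantics is below; `DummyMemoryCache` and its `allocate`/`use_cache` methods are illustrative stand-ins, not the real `MemoryCache` API from `src/server/cache.py`.

```python
from contextlib import contextmanager
from typing import Dict, Tuple

import torch


class DummyMemoryCache:
    """Illustrative stand-in for MemoryCache: maps integer handles to tensors."""

    def __init__(self):
        self._buffers: Dict[int, torch.Tensor] = {}

    def allocate(self, handle: int, shape: Tuple[int, ...]) -> None:
        self._buffers[handle] = torch.zeros(shape)

    @contextmanager
    def use_cache(self, handle: int):
        # Yield the cached tensor so the caller can mutate it in place.
        yield self._buffers[handle]


def inference_step(memory_cache: DummyMemoryCache, *inputs: torch.Tensor, attention_cache_handle: int) -> torch.Tensor:
    # Mirrors the placeholder body of TransformerBackend.inference_step above:
    # bump the cached tensor, then add it to the first input.
    with memory_cache.use_cache(attention_cache_handle) as cache:
        cache[...] += 1
        return inputs[0] + cache


cache = DummyMemoryCache()
cache.allocate(handle=0, shape=(3,))
hidden = torch.zeros(3)
print(inference_step(cache, hidden, attention_cache_handle=0))  # tensor([1., 1., 1.])
print(inference_step(cache, hidden, attention_cache_handle=0))  # tensor([2., 2., 2.])
```

The second call returns a larger value because the cache persists across calls under the same handle, which is the behavior the placeholder is meant to exercise before real attention-key/value caching is wired in.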