switch to hivemind-master

inference_chain
justheuristic 2 years ago
parent 57f4e0a899
commit 5a15c13ca7

@@ -18,7 +18,7 @@ conda activate bloom-demo
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
 pip install bitsandbytes-cuda113==0.26.0
-pip install https://github.com/learning-at-home/hivemind/archive/bc2cccfdb0d7c905a12ef6c3ad052a1250af9878.zip
+pip install https://github.com/learning-at-home/hivemind/archive/master.zip
 pip install https://github.com/huggingface/transformers/archive/224bde91caff4ccfd12277ab5e9bf97c61e22ee9.zip
 ```

@@ -8,8 +8,8 @@ from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.dht_handler import _get_experts
 from hivemind.p2p import StubBase, P2P
 from hivemind.proto.runtime_pb2 import ExpertInfo
-from hivemind.dht import DHTExpiration, DHT
-from hivemind.utils import MPFuture
+from hivemind.dht import DHT
+from hivemind.utils import MPFuture, DHTExpiration
 from src.server.handler import TransformerConnectionHandler

@@ -2,14 +2,13 @@
 from typing import Tuple, Sequence
 import torch
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
-class TransformerBlockBackend(ExpertBackend):
+class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
@@ -21,12 +20,12 @@ class TransformerBlockBackend(ExpertBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
     def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
-        with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
-            return inputs[0] * 2
+        with self.memory_cache.use_cache(attention_cache_handle) as cache:
+            cache[...] += 1
+            return inputs[0] + cache
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
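
The `inference_step` in the hunk above is still a placeholder: it bumps a per-session cache tensor in place and adds it to the input, just enough to show that state survives between calls routed through the inference pool. A minimal, self-contained sketch of that call pattern follows; `ToyMemoryCache` is an assumption standing in for the real `src.server.cache.MemoryCache`, whose handles are obtained via `allocate_cache` and reused via `use_cache`:

```
import contextlib
import torch

class ToyMemoryCache:
    """Dict-backed stand-in for src.server.cache.MemoryCache (illustration only)."""

    def __init__(self):
        self._tensors = {}
        self._next_handle = 0

    @contextlib.contextmanager
    def allocate_cache(self, descriptor: torch.Tensor):
        # hand out an integer handle for a zero-initialized cache tensor
        handle, self._next_handle = self._next_handle, self._next_handle + 1
        self._tensors[handle] = torch.zeros_like(descriptor)
        yield handle

    @contextlib.contextmanager
    def use_cache(self, handle: int):
        # expose the cached tensor for in-place updates, as inference_step does
        yield self._tensors[handle]

cache = ToyMemoryCache()
with cache.allocate_cache(torch.zeros(3)) as handle:
    hidden = torch.ones(3)
    for step in range(2):
        with cache.use_cache(handle) as buf:
            buf[...] += 1               # the cache persists across steps
            print(step, hidden + buf)   # mirrors `return inputs[0] + cache`
```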

@@ -1,19 +1,21 @@
 from typing import AsyncIterator, Dict
 import torch
-from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
+from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor, ModuleBackend
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
-from src.server.backend import TransformerBlockBackend
+from src.server.backend import TransformerBackend
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
+        super().__init__(dht, module_backends)
+        for module_backend in module_backends.values():
+            assert isinstance(module_backend, TransformerBackend)
     async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -21,7 +23,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         request = await anext(requests)
         backend = self.experts[request.uid]
-        assert isinstance(backend, TransformerBlockBackend)
+        assert isinstance(backend, TransformerBackend)
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
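
On the handler side, `rpc_inference` allocates an attention cache for the lifetime of the request stream before serving any steps. Below is a hedged sketch of that allocate-then-serve flow using a toy async cache; the real `MemoryCache.allocate_cache` takes a hivemind `TensorDescriptor`, and the sizing here (like the `torch.randn(3)` placeholder above) is illustrative only:

```
import asyncio
import contextlib
import torch

class ToyAsyncCache:
    """Async stand-in for MemoryCache.allocate_cache (illustration only)."""

    def __init__(self):
        self._tensors = {}
        self._next_handle = 0

    @contextlib.asynccontextmanager
    async def allocate_cache(self, like: torch.Tensor):
        handle, self._next_handle = self._next_handle, self._next_handle + 1
        self._tensors[handle] = torch.zeros_like(like)  # lives as long as the stream
        try:
            yield handle
        finally:
            del self._tensors[handle]  # freed once the inference stream ends

async def serve_stream(cache: ToyAsyncCache, num_steps: int) -> None:
    # allocate once per stream, then reuse the same handle for every step
    async with cache.allocate_cache(torch.zeros(3)) as handle:
        for step in range(num_steps):
            print(f"step {step}: using cache handle {handle}")

asyncio.run(serve_stream(ToyAsyncCache(), num_steps=2))
```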

@@ -14,7 +14,7 @@ import multiprocessing as mp
 from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.backend import TransformerBlockBackend
+from src.server.backend import TransformerBackend
 from src.server.handler import TransformerConnectionHandler
 use_hivemind_log_handler("in_root_logger")
@@ -27,7 +27,7 @@ class Server(threading.Thread):
     def __init__(
         self,
         dht: DHT,
-        module_backends: Dict[str, TransformerBlockBackend],
+        module_backends: Dict[str, TransformerBackend],
         *,
         device: torch.device,
         num_connection_handlers: int = 8,
@@ -118,7 +118,7 @@
         for param in block.parameters():
             param.requires_grad = False
-        blocks[module_uid] = TransformerBlockBackend(
+        blocks[module_uid] = TransformerBackend(
             module_uid,
             block,
             memory_cache=memory_cache,
