diff --git a/src/server/backend.py b/src/server/backend.py
index 4b1d411..3b6a9fe 100644
--- a/src/server/backend.py
+++ b/src/server/backend.py
@@ -1,8 +1,7 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Tuple
+from typing import Tuple, Sequence
 
 import torch
-from hivemind import BatchTensorDescriptor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.task_pool import TaskPool
 
@@ -14,7 +13,7 @@ class TransformerBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
-        super().__init__(*args, **kwargs)  # to bypass super.__init__
+        super().__init__(*args, **kwargs)
         self.memory_cache = memory_cache
 
         for name, param in self.module.named_parameters():
@@ -22,6 +21,12 @@ class TransformerBlockBackend(ExpertBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-    def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
+
+        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+
+    def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
         with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             return inputs[0] * 2
+
+    def get_pools(self) -> Sequence[TaskPool]:
+        return self.forward_pool, self.backward_pool, self.inference_pool
diff --git a/src/server/handler.py b/src/server/handler.py
index 10a1395..b849e84 100644
--- a/src/server/handler.py
+++ b/src/server/handler.py
@@ -15,24 +15,16 @@ class TransformerConnectionHandler(ConnectionHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    async def rpc_forward_incremental(
+    async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         request = await anext(requests)
-        expert = self.experts[request.uid]
-        assert isinstance(expert, TransformerBlockBackend)
+        backend = self.experts[request.uid]
+        assert isinstance(backend, TransformerBlockBackend)
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
-        async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
-            outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+        async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
+            outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
 
-        return runtime_pb2.ExpertResponse(tensors=outputs)
-
-
-        # note: you may use self.experts[uid].memory_cache!
-        # encode expert_uid as @model_name[starting_layer:finishing_layer]
-        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
-        # - receive and maintain a handle for attention cache here
-
-        raise NotImplementedError()
+        yield runtime_pb2.ExpertResponse(tensors=outputs)
 
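
The removed TODO notes in rpc_forward_incremental describe the intended protocol: the handler allocates an attention-cache slot for the lifetime of one inference session and passes its handle to the backend on every step, where inference_step resolves the handle via memory_cache.use_cache. The sketch below illustrates that handle-based allocate/use pattern in isolation. It is a minimal stand-alone example, not the project's implementation: ToyMemoryCache, its allocate_cache(max_length, head_dim) signature, and the free-standing inference_step function are hypothetical stand-ins (the real MemoryCache in this repo allocates from a TensorDescriptor and is shared between the handler and the runtime process), and the layer computation is the same inputs * 2 placeholder used in the diff.

# Hypothetical, simplified sketch of the handle-based attention-cache flow.
# ToyMemoryCache is an invented in-process stand-in for the repo's MemoryCache.
import contextlib
from typing import Dict, Tuple

import torch


class ToyMemoryCache:
    """Maps integer handles to (current_length, cached_keys, cached_values)."""

    def __init__(self):
        self._slots: Dict[int, Tuple[int, torch.Tensor, torch.Tensor]] = {}
        self._next_handle = 0

    @contextlib.contextmanager
    def allocate_cache(self, max_length: int, head_dim: int):
        # The connection handler allocates one slot per inference session and
        # sends the integer handle to the backend with every request.
        handle = self._next_handle
        self._next_handle += 1
        self._slots[handle] = (0, torch.zeros(max_length, head_dim), torch.zeros(max_length, head_dim))
        try:
            yield handle
        finally:
            del self._slots[handle]  # free the slot when the session ends

    @contextlib.contextmanager
    def use_cache(self, handle: int):
        # The backend resolves the handle back to the previously allocated tensors.
        current_length, keys, values = self._slots[handle]
        yield current_length, keys, values


def inference_step(cache: ToyMemoryCache, handle: int, hidden: torch.Tensor) -> torch.Tensor:
    # Mirrors TransformerBlockBackend.inference_step: look up the cache slot,
    # then run the layer. A real step would append new keys/values to the cache;
    # here we keep the diff's placeholder computation.
    with cache.use_cache(handle) as (current_length, cached_keys, cached_values):
        return hidden * 2


if __name__ == "__main__":
    cache = ToyMemoryCache()
    with cache.allocate_cache(max_length=8, head_dim=4) as handle:
        print(inference_step(cache, handle, torch.ones(1, 4)))

Passing a plain integer handle, rather than the cached tensors themselves, is what lets the handler keep a session's key/value cache alive across many small RPC steps while the heavy buffers stay inside the server's memory cache.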