inference mode

This commit is contained in:
justheuristic 2022-06-14 14:51:06 +03:00
parent 3b16d6ffdb
commit e3a7d5af30
2 changed files with 15 additions and 18 deletions

View File

@ -1,8 +1,7 @@
"""Code for serving bloom blocks via hivemind-server"""
from typing import Tuple
from typing import Tuple, Sequence
import torch
from hivemind import BatchTensorDescriptor
from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.moe.server.task_pool import TaskPool
@ -14,7 +13,7 @@ class TransformerBlockBackend(ExpertBackend):
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
super().__init__(*args, **kwargs) # to bypass super.__init__
super().__init__(*args, **kwargs)
self.memory_cache = memory_cache
for name, param in self.module.named_parameters():
@ -22,6 +21,12 @@ class TransformerBlockBackend(ExpertBackend):
for name, buf in self.module.named_buffers():
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
return inputs[0] * 2
def get_pools(self) -> Sequence[TaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool

View File

@ -15,24 +15,16 @@ class TransformerConnectionHandler(ConnectionHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def rpc_forward_incremental(
async def rpc_inference(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
request = await anext(requests)
expert = self.experts[request.uid]
assert isinstance(expert, TransformerBlockBackend)
backend = self.experts[request.uid]
assert isinstance(backend, TransformerBlockBackend)
inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
return runtime_pb2.ExpertResponse(tensors=outputs)
# note: you may use self.experts[uid].memory_cache!
# encode expert_uid as @model_name[starting_layer:finishing_layer]
# - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
# - receive and maintain a handle for attention cache here
raise NotImplementedError()
yield runtime_pb2.ExpertResponse(tensors=outputs)