|
|
|
@ -1,10 +1,12 @@
|
|
|
|
|
from typing import AsyncIterator, Dict
|
|
|
|
|
|
|
|
|
|
from hivemind import P2PContext, DHT
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
|
|
|
|
|
from hivemind.moe.server.connection_handler import ConnectionHandler
|
|
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
|
|
from hivemind.utils.asyncio import anext
|
|
|
|
|
|
|
|
|
|
from src.bloom.block import BloomBlock
|
|
|
|
|
from src.server.backend import TransformerBlockBackend
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
@ -16,6 +18,18 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
async def rpc_forward_incremental(
|
|
|
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
|
|
|
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
|
|
|
|
|
|
|
|
|
request = await anext(requests)
|
|
|
|
|
expert = self.experts[request.uid]
|
|
|
|
|
assert isinstance(expert, TransformerBlockBackend)
|
|
|
|
|
|
|
|
|
|
inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
|
|
|
|
async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
|
|
|
|
|
outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
|
|
|
|
|
|
|
|
|
|
return runtime_pb2.ExpertResponse(tensors=outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# note: you may use self.experts[uid].memory_cache!
|
|
|
|
|
# encode expert_uid as @model_name[starting_layer:finishing_layer]
|
|
|
|
|
# - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
|
|
|
|
|