black-isort

inference_chain
justheuristic 2 years ago
parent a28ea0aa6f
commit 3b16d6ffdb

@@ -10,14 +10,7 @@ from src.bloom.block import BloomBlock
from src.server.cache import MemoryCache
# TODO
# BloomBackend serves a single layer
# - ensure that parameters do not require grad!
# - ensure that TaskPool for inference is NOT batched
# - ensure that optimizer/scheduler is not created
-class BloomBlockBackend(ExpertBackend):
+class TransformerBlockBackend(ExpertBackend):
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
@@ -31,4 +24,4 @@ class BloomBlockBackend(ExpertBackend):
def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
-raise NotImplementedError("TODO")
+return inputs[0] * 2

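Note: return inputs[0] * 2 above is only a placeholder. Below is a rough sketch of how forward_incremental might eventually consume the attention cache, assuming the block is stored on self.module and accepts layer_past/use_cache keywords; neither the call signature nor the cache write-back is defined by this commit.

# hypothetical body for TransformerBlockBackend.forward_incremental; not part of this commit
def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
    hidden_states = inputs[0]  # [batch_size, n_new_tokens, hidden_size]
    with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
        # attend over the previously cached prefix; the exact BloomBlock kwargs are an assumption
        layer_past = (cached_keys[:, :current_length], cached_values[:, :current_length])
        hidden_states, present = self.module(hidden_states, layer_past=layer_past, use_cache=True)
        # writing the updated keys/values back into the cache is deliberately left out of this sketch
        return (hidden_states,)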
@@ -1,10 +1,12 @@
from typing import AsyncIterator, Dict
-from hivemind import P2PContext, DHT
+import torch
+from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import anext
from src.bloom.block import BloomBlock
from src.server.backend import TransformerBlockBackend
class TransformerConnectionHandler(ConnectionHandler):
@@ -16,6 +18,18 @@ class TransformerConnectionHandler(ConnectionHandler):
async def rpc_forward_incremental(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
request = await anext(requests)
expert = self.experts[request.uid]
assert isinstance(expert, TransformerBlockBackend)
inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
return runtime_pb2.ExpertResponse(tensors=outputs)
# note: you may use self.experts[uid].memory_cache!
# encode expert_uid as @model_name[starting_layer:finishing_layer]
# - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat

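The @model_name[starting_layer:finishing_layer] encoding mentioned in the comments above is still a TODO in this commit. A self-contained sketch of a parser for that scheme, assuming this exact textual format:

import re
from typing import Tuple

# hypothetical uid parser; the format is an assumption taken from the TODO comment above
_BLOCK_UID_PATTERN = re.compile(r"^@(?P<model>[^\[\]]+)\[(?P<start>\d+):(?P<end>\d+)\]$")

def parse_block_uid(uid: str) -> Tuple[str, int, int]:
    """Split a uid such as "@bigscience/bloom[12:24]" into (model_name, start_layer, finish_layer)."""
    match = _BLOCK_UID_PATTERN.match(uid)
    if match is None:
        raise ValueError(f"Malformed block uid: {uid!r}")
    return match["model"], int(match["start"]), int(match["end"])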
@@ -14,7 +14,7 @@ import multiprocessing as mp
from src import DistributedBloomConfig
from src.bloom.block import BloomBlock
from src.server.cache import MemoryCache
-from src.server.backend import BloomBlockBackend
+from src.server.backend import TransformerBlockBackend
from src.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
@@ -27,7 +27,7 @@ class Server(threading.Thread):
def __init__(
self,
dht: DHT,
-module_backends: Dict[str, BloomBlockBackend],
+module_backends: Dict[str, TransformerBlockBackend],
*,
device: torch.device,
num_connection_handlers: int = 8,
@@ -118,7 +118,7 @@
for param in block.parameters():
param.requires_grad = False
-blocks[module_uid] = BloomBlockBackend(
+blocks[module_uid] = TransformerBlockBackend(
module_uid,
block,
memory_cache=memory_cache,

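For context, a minimal wiring sketch matching the Server constructor shown above; the import paths, MemoryCache arguments, and cache size are assumptions rather than anything this diff specifies:

# hypothetical usage; only Server's signature (dht, module_backends, device, num_connection_handlers) comes from the diff
import torch
from hivemind import DHT
from src.server.cache import MemoryCache
from src.server.server import Server  # module path is an assumption

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dht = DHT(start=True)
memory_cache = MemoryCache(device=device, max_size_bytes=2 ** 34)  # constructor args are an assumption
module_backends = {}  # Dict[str, TransformerBlockBackend], filled by the block-loading loop above using memory_cache
server = Server(dht, module_backends, device=device, num_connection_handlers=8)
server.start()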