mirror of https://github.com/bigscience-workshop/petals

commit e3a7d5af30
parent 3b16d6ffdb

    inference mode
@@ -1,8 +1,7 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Tuple
+from typing import Tuple, Sequence

 import torch
 from hivemind import BatchTensorDescriptor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.task_pool import TaskPool

@@ -14,7 +13,7 @@ class TransformerBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""

     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
-        super().__init__(*args, **kwargs)  # to bypass super.__init__
+        super().__init__(*args, **kwargs)
         self.memory_cache = memory_cache

         for name, param in self.module.named_parameters():
@@ -22,6 +21,12 @@ class TransformerBlockBackend(ExpertBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"

-    def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
+        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+
+    def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
         with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             return inputs[0] * 2
+
+    def get_pools(self) -> Sequence[TaskPool]:
+        return self.forward_pool, self.backward_pool, self.inference_pool
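As committed, inference_step is still a placeholder: it opens the attention cache and returns inputs[0] * 2. Below is a minimal sketch of what a real step might look like, assuming use_cache yields (current_length, cached_keys, cached_values) with [batch, seq, ...] buffers and that the wrapped BloomBlock follows a HuggingFace-style layer_past/use_cache interface; both the buffer layout and the block signature are assumptions, not this commit's code.

    def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
        (hidden_states,) = inputs  # embeddings of the newly generated tokens, assumed [batch, new_tokens, hidden]
        with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
            past_length = int(current_length)
            # Everything cached so far becomes the attention "past"; nothing on the very first step.
            layer_past = (cached_keys[:, :past_length], cached_values[:, :past_length]) if past_length else None
            # Assumed HF-style block interface: returns new hidden states plus updated key/value tensors.
            hidden_states, (new_keys, new_values) = self.module(hidden_states, layer_past=layer_past, use_cache=True)
            # Write the updated keys/values back into the preallocated cache buffers.
            new_length = new_keys.shape[1]
            cached_keys[:, :new_length].copy_(new_keys)
            cached_values[:, :new_length].copy_(new_values)
            # Advancing current_length is left to MemoryCache bookkeeping, which this sketch does not model.
            return (hidden_states,)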
@@ -15,24 +15,16 @@ class TransformerConnectionHandler(ConnectionHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-    async def rpc_forward_incremental(
+    async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         request = await anext(requests)
-        expert = self.experts[request.uid]
-        assert isinstance(expert, TransformerBlockBackend)
+        backend = self.experts[request.uid]
+        assert isinstance(backend, TransformerBlockBackend)

         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
-            outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+        async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
+            outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)

-        return runtime_pb2.ExpertResponse(tensors=outputs)
-
-        # note: you may use self.experts[uid].memory_cache!
-        # encode expert_uid as @model_name[starting_layer:finishing_layer]
-        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
-        # - receive and maintain a handle for attention cache here
-
-        raise NotImplementedError()
+        yield runtime_pb2.ExpertResponse(tensors=outputs)
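The TODO comments deleted above describe where this handler is headed: keep one attention cache alive for the whole stream and run one inference step per incoming message, rather than answering a single request. A rough sketch of that loop is shown below; the `as handle` binding, the handle forwarding, and the empty-message end-of-stream convention are assumptions, since the committed code still serves exactly one step.

    async def rpc_inference(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Use the first message to pick the backend (the uid could encode @model_name[starting_layer:finishing_layer]).
        first_request = await anext(requests)
        backend = self.experts[first_request.uid]
        assert isinstance(backend, TransformerBlockBackend)

        # Assumption: allocate_cache yields a handle that the backend can later resolve via use_cache.
        async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))) as handle:
            request = first_request
            while request is not None and request.tensors:  # an empty message marks the end of the stream
                inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
                # (shape checks against backend.args_schema would go here)
                # Note: forwarding `handle` into inference_step is not wired up in this commit.
                outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
                yield runtime_pb2.ExpertResponse(tensors=outputs)
                request = await anext(requests, None)  # keep serving steps until the client stops sending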
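Taken together, allocate_cache (handler side) and use_cache (backend side) form the handle-passing protocol this commit starts building. The toy, single-process stand-in below mimics just that contract; the real MemoryCache takes a hivemind TensorDescriptor rather than a prototype tensor and presumably has to coordinate memory between the connection handlers and the runtime, none of which is modeled here.

    import contextlib
    from typing import Dict, Tuple

    import torch


    class ToyMemoryCache:
        """Toy stand-in for MemoryCache: integer handles mapped to preallocated key/value buffers."""

        def __init__(self):
            self._buffers: Dict[int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
            self._next_handle = 0

        @contextlib.asynccontextmanager
        async def allocate_cache(self, prototype: torch.Tensor):
            # Handler side: reserve key/value space for one inference session.
            handle, self._next_handle = self._next_handle, self._next_handle + 1
            current_length = torch.zeros((), dtype=torch.int64)
            self._buffers[handle] = (current_length, torch.zeros_like(prototype), torch.zeros_like(prototype))
            try:
                yield handle  # the handle is what would be forwarded to inference_step
            finally:
                del self._buffers[handle]  # free the session's cache when the stream ends

        @contextlib.contextmanager
        def use_cache(self, handle: int):
            # Backend side: look up (current_length, cached_keys, cached_values) for this session.
            yield self._buffers[handle]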