@@ -1,16 +1,46 @@
"""Code for serving bloom blocks via hivemind-server"""
|
|
|
|
|
from queue import Empty
|
|
|
|
|
from typing import Sequence, Tuple
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind import use_hivemind_log_handler
|
|
|
|
|
from hivemind.moe.server.module_backend import ModuleBackend
|
|
|
|
|
from hivemind.moe.server.task_pool import TaskPool
|
|
|
|
|
from hivemind.utils import InvalidStateError, get_logger
|
|
|
|
|
|
|
|
|
|
from src.bloom.from_pretrained import BloomBlock
|
|
|
|
|
from src.server.cache import MemoryCache
|
|
|
|
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
MAX_LENGTH = 2048
 
 
+class InferenceTaskPool(TaskPool):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater than 1"
+
+    def iterate_minibatches(self, *args, **kwargs):
+        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
+        while True:
+            try:
+                logger.debug(f"{self.name} getting next task")
+                task = self.tasks.get(timeout=self.timeout)
+            except Empty:
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
+
+            try:
+                if task.future.set_running_or_notify_cancel():
+                    yield [task]
+            except InvalidStateError as e:
+                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
+
+
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
@@ -23,7 +53,9 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+        self.inference_pool = InferenceTaskPool(
+            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
+        )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
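
For reviewers unfamiliar with hivemind's task pools, here is a minimal, self-contained sketch of the batching behavior that `InferenceTaskPool.iterate_minibatches` implements: every yielded batch contains exactly one task (note the `yield [task]`), presumably because each incremental-inference request carries its own attention-cache state and cannot be concatenated the way stateless forward/backward requests can. The `Task` namedtuple, the `iterate_singleton_minibatches` helper, and the demo queue contents are illustrative stand-ins built on `concurrent.futures`, not hivemind APIs, and the sketch returns on timeout where the real pool logs a warning and keeps waiting.

```python
import queue
from collections import namedtuple
from concurrent.futures import Future, InvalidStateError

# Hypothetical stand-in for hivemind's task record: a future plus its payload.
Task = namedtuple("Task", ["future", "args"])


def iterate_singleton_minibatches(tasks, timeout=1.0):
    """Yield one-task batches, skipping tasks whose future was already cancelled."""
    while True:
        try:
            task = tasks.get(timeout=timeout)
        except queue.Empty:
            return  # the real pool keeps waiting instead of returning
        try:
            # Returns False if the client cancelled the request while it was queued.
            if task.future.set_running_or_notify_cancel():
                yield [task]  # a singleton batch: inference requests are never fused
        except InvalidStateError:
            pass  # mirrors the pool's handling of hivemind futures; drop the task


# Demo: the cancelled task is skipped, the live one arrives as a one-task batch.
q = queue.Queue()
live, dead = Future(), Future()
dead.cancel()
q.put(Task(dead, args=()))
q.put(Task(live, args=()))
for batch in iterate_singleton_minibatches(q):
    assert len(batch) == 1 and batch[0].future is live
```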
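The second hunk routes `inference_step` through the new pool; note that even with `max_batch_size` raised to match the forward pool's, dispatch still happens one request at a time because of the singleton batches above. As a hedged illustration of the `inference_step` contract (its body is only partially visible in this diff), the toy backend below runs a stand-in module under `torch.inference_mode`. The `ToyBackend` class and the tensor shapes are assumptions, and the toy ignores `cache_metadata`, which the real method presumably uses to locate its attention cache in `MemoryCache`.

```python
from typing import Tuple

import torch


class ToyBackend:
    """Hypothetical stand-in for TransformerBackend with a single linear 'block'."""

    def __init__(self) -> None:
        self.module = torch.nn.Linear(8, 8)  # stand-in for a BloomBlock

    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Same guard as in the diff: no autograd graph is recorded for inference.
        with torch.inference_mode():
            (hidden_states,) = inputs
            # The real backend would read/write its attention cache here,
            # keyed by cache_metadata; this toy ignores it.
            return (self.module(hidden_states),)


backend = ToyBackend()
(out,) = backend.inference_step(torch.tensor([0, 0], dtype=torch.int32), torch.randn(1, 4, 8))
assert out.shape == (1, 4, 8) and not out.requires_grad
```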