from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

import hivemind
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.expert import DUMMY, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack

from load_balancer import LoadBalancer

logger = get_logger(__name__)
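
# Note: LoadBalancer is a local module rather than part of hivemind. From its usage in this file, it is
# assumed to expose use_another_expert(task_size) as a context manager that yields a RemoteExpert-like
# object (with .uid, .stub and .info) selected with probability proportional to measured throughput.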

MAX_NODES = 99999  # effective upper bound on the number of experts in one grid row


class BalancedRemoteExpert(nn.Module):
    """
    A torch module that routes each call to one RemoteExpert from a pool, choosing experts
    with probability proportional to their measured throughput.

    :param dht: a running hivemind.DHT instance used to discover experts
    :param uid_prefix: common prefix of expert uids; experts are looked up under f"{uid_prefix}0."
    :param grid_size: dimensions of the expert grid; only 1xN grids are supported
    :param forward_timeout: timeout (in seconds) for the forward pass, None means no timeout
    :param backward_timeout: timeout (in seconds) for the backward pass, None means no timeout
    :param update_period: interval (in seconds) at which the load balancer refreshes expert throughput info
    :param backward_task_size_multiplier: assume a backward pass costs this many times more than a forward pass
    :param kwargs: extra keyword arguments forwarded to the LoadBalancer
    """

    def __init__(
        self,
        *,
        dht: hivemind.DHT,
        uid_prefix: str,
        grid_size: Tuple[int, ...] = (1, MAX_NODES),
        forward_timeout: Optional[float] = None,
        backward_timeout: Optional[float] = None,
        update_period: float = 30.0,
        backward_task_size_multiplier: float = 2.5,
        **kwargs,
    ):
        super().__init__()
        if uid_prefix.endswith(".0."):
            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {uid_prefix}0.")
        assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
        self.backward_task_size_multiplier = backward_task_size_multiplier
        self.expert_balancer = LoadBalancer(dht, key=f"{self.uid_prefix}0.", update_period=update_period, **kwargs)
        self._expert_info = None  # expert['info'] from one of the experts in the grid

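    # Presumably, experts are registered in the DHT as f"{uid_prefix}0.{index}", so the balancer key
    # f"{uid_prefix}0." covers the single row permitted by the 1xN grid assertion above.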
    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Call one of the RemoteExperts with the given inputs and return its output. Compatible with pytorch.autograd.

        :param args: input tensors that will be passed to the chosen expert, batch-first
        :param kwargs: extra keyword tensors that will be passed to the chosen expert, batch-first
        :returns: output of the chosen expert, a nested structure of batch-first tensors
        """
        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
        # Note: we put keyword arguments in the same order as on the server to prevent f(a=1, b=2) != f(b=2, a=1) errors
        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}

        if self._expert_info is None:
            raise RuntimeError("Could not retrieve expert info from any expert in the grid")

        forward_inputs = (args, kwargs)

        if not nested_compare(forward_inputs, self.info["forward_schema"]):
            raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

        flat_inputs = list(nested_flatten(forward_inputs))
        forward_task_size = flat_inputs[0].shape[0]

        # Note: we send DUMMY to prevent torch from excluding the expert from backward if no other inputs require grad
        flat_outputs = _BalancedRemoteModuleCall.apply(
            DUMMY,
            self.expert_balancer,
            self.info,
            self.forward_timeout,
            self.backward_timeout,
            forward_task_size,
            forward_task_size * self.backward_task_size_multiplier,
            *flat_inputs,
        )

        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
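
    # For reference, nested_flatten/nested_pack round-trip nested structures of tensors; illustrative example:
    #   nested_flatten(((a, b), {"mask": m}))                      -> iterator over a, b, m
    #   nested_pack([a2, b2, m2], structure=((a, b), {"mask": m})) -> ((a2, b2), {"mask": m2})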

    @property
    def info(self):
        while self._expert_info is None:
            chosen_expert = None
            try:
                with self.expert_balancer.use_another_expert(1) as chosen_expert:
                    self._expert_info = chosen_expert.info
            except BaseException as e:
                logger.error(f"Tried to get expert info from {chosen_expert} but caught {repr(e)}")
        return self._expert_info
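

# Usage sketch (illustrative, not from the original file): assumes a hivemind swarm where servers have
# declared experts under the hypothetical prefix "expert."; peers, prefix and shapes are placeholders.
#
#   dht = hivemind.DHT(initial_peers=initial_peers, start=True)
#   expert = BalancedRemoteExpert(dht=dht, uid_prefix="expert.")
#   outputs = expert(torch.randn(16, 1024))  # each call is routed to one expert chosen by throughput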


class _BalancedRemoteModuleCall(torch.autograd.Function):
    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""

    @staticmethod
    def forward(
        ctx,
        dummy: torch.Tensor,
        expert_balancer: LoadBalancer,
        info: Dict[str, Any],
        forward_timeout: float,
        backward_timeout: float,
        forward_task_size: float,
        backward_task_size: float,
        *inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['forward_schema']
        ctx.expert_balancer, ctx.info = expert_balancer, info
        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
        # detach inputs and move them to CPU to avoid pickling the computation graph when sending them over the wire
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.save_for_backward(*inputs)

        serialized_tensors = [
            serialize_torch_tensor(inp, proto.compression)
            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
        ]
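        # Each tensor is serialized with the CompressionType declared for it in the expert's forward_schema,
        # so the wire format matches what the serving side expects.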

        chosen_expert = None
        while True:
            try:
                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
                    deserialized_outputs = RemoteExpertWorker.run_coroutine(
                        expert_forward(chosen_expert.uid, inputs, serialized_tensors, chosen_expert.stub)
                    )
                break
            except BaseException as e:
                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")

        return tuple(deserialized_outputs)
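
    # Note: the loop above retries indefinitely with no retry cap, re-sending the same serialized inputs
    # to a newly chosen expert each time; it relies on the LoadBalancer eventually yielding a responsive peer.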

    @staticmethod
    @once_differentiable
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        raise NotImplementedError("Backward is not yet implemented in this example")
        # A possible implementation, mirroring the forward pass:
        # grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
        # inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
        # backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
        # serialized_tensors = [
        #     serialize_torch_tensor(tensor, proto.compression)
        #     for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
        # ]
        # chosen_expert = None
        # while True:
        #     try:
        #         with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
        #             backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
        #             grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
        #         break
        #     except BaseException as e:
        #         logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
        # deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
        # return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)