You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

153 lines
7.4 KiB

import asyncio
from typing import Any, Dict, Optional, Tuple

import hivemind
import torch
import torch.nn as nn
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.expert import DUMMY, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
from torch.autograd.function import once_differentiable

from load_balancer import LoadBalancer
# Default size of the second (N) dimension of BalancedRemoteExpert's 1xN expert grid.
MAX_NODES = 99_999

logger = get_logger(__name__)
class BalancedRemoteExpert(nn.Module):
    """
    A torch module that routes every call to one RemoteExpert chosen from a pool by a LoadBalancer.

    Experts are discovered under the DHT key ``f"{uid_prefix}0."`` (only 1xN grids are supported). The original
    description states that the balancer assigns work proportionally to expert throughput; see LoadBalancer.

    :param dht: DHT instance handed to the LoadBalancer for expert discovery
    :param uid_prefix: expert uid prefix; experts are looked up under ``f"{uid_prefix}0."``
    :param grid_size: expert grid dimensions, must be of the form ``(1, N)``
    :param forward_timeout: forwarded to _BalancedRemoteModuleCall as the forward timeout
    :param backward_timeout: forwarded to _BalancedRemoteModuleCall as the backward timeout
    :param update_period: passed to LoadBalancer as ``update_period``
    :param backward_task_size_multiplier: backward task size = forward task size * this value
    :param kwargs: extra keyword arguments passed through to LoadBalancer
    """

    def __init__(
        self,
        dht: hivemind.DHT,
        uid_prefix: str,
        grid_size: Tuple[int, ...] = (1, MAX_NODES),
        forward_timeout: Optional[float] = None,
        backward_timeout: Optional[float] = None,
        update_period: float = 30.0,
        backward_task_size_multiplier: float = 2.5,
        **kwargs,
    ):
        # nn.Module must be initialised before attributes are assigned and before the module is called.
        super().__init__()
        if uid_prefix.endswith(".0."):
            # self.uid_prefix is not assigned yet at this point, so use the argument directly.
            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {uid_prefix}0.")
        assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
        self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
        self.backward_task_size_multiplier = backward_task_size_multiplier
        self.expert_balancer = LoadBalancer(dht, key=f"{self.uid_prefix}0.", update_period=update_period, **kwargs)
        # expert['info'] from one of the experts in the grid; fetched lazily by the `info` property
        self._expert_info: Optional[Dict[str, Any]] = None

    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Call one of the RemoteExperts for the specified inputs and return its output. Compatible with pytorch.autograd.

        :param args: input tensors that will be passed to the chosen expert, batch-first
        :param kwargs: extra keyword tensors that will be passed to the chosen expert, batch-first;
            their names must match ``self.info["keyword_names"]``
        :returns: outputs of the chosen expert, packed into the nested structure of ``self.info["outputs_schema"]``
        :raises TypeError: if the inputs do not match ``self.info["forward_schema"]``
        """
        info = self.info  # may block, retrying until some expert returns its info
        assert len(kwargs) == len(info["keyword_names"]), f"Keyword args should be {info['keyword_names']}"
        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
        kwargs = {key: kwargs[key] for key in info["keyword_names"]}

        forward_inputs = (args, kwargs)
        if not nested_compare(forward_inputs, info["forward_schema"]):
            raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

        flat_inputs = list(nested_flatten(forward_inputs))
        # Batch size of the first input; reported to the balancer as the task size of this call.
        forward_task_size = flat_inputs[0].shape[0]

        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
        flat_outputs = _BalancedRemoteModuleCall.apply(
            DUMMY,
            self.expert_balancer,
            info,
            self.forward_timeout,
            self.backward_timeout,
            forward_task_size,
            forward_task_size * self.backward_task_size_multiplier,
            *flat_inputs,
        )
        return nested_pack(flat_outputs, info["outputs_schema"])

    @property
    def info(self) -> Dict[str, Any]:
        """
        Expert metadata (keyword names, forward/outputs schemas) fetched from any one expert and cached.

        Retries with another expert until one succeeds, so this blocks while no expert can be reached.
        Only ``Exception`` subclasses are retried; KeyboardInterrupt / SystemExit propagate.
        """
        while self._expert_info is None:
            chosen_expert = None  # stays None if the balancer fails before yielding an expert
            try:
                with self.expert_balancer.use_another_expert(1) as chosen_expert:
                    self._expert_info = chosen_expert.info
            except Exception as e:
                logger.error(f"Tried to get expert info from {chosen_expert} but caught {repr(e)}")
        return self._expert_info
class _BalancedRemoteModuleCall(torch.autograd.Function):
    """Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""

    @staticmethod
    def forward(
        ctx,
        dummy: torch.Tensor,
        expert_balancer: LoadBalancer,
        info: Dict[str, Any],
        forward_timeout: Optional[float],
        backward_timeout: Optional[float],
        forward_task_size: float,
        backward_task_size: float,
        *inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Run the expert forward pass on one expert picked by ``expert_balancer``, retrying with another on failure.

        :param ctx: autograd context; keeps the balancer, info, timeouts and task sizes for backward
        :param dummy: placeholder tensor that keeps this node in the autograd graph; not used in the computation
        :param expert_balancer: balancer whose ``use_another_expert(task_size)`` yields the expert to call
        :param info: expert info; ``info["forward_schema"]`` supplies per-input compression settings
        :param forward_timeout: seconds allowed for a single forward attempt; ``None`` waits indefinitely
        :param backward_timeout: stored on ``ctx`` for the backward pass
        :param forward_task_size: task size reported to the balancer for this call
        :param backward_task_size: stored on ``ctx`` for the backward pass
        :param inputs: flattened input tensors that follow ``info["forward_schema"]``
        :returns: output tensors returned by the chosen expert
        """
        ctx.expert_balancer, ctx.info = expert_balancer, info
        ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
        ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size

        # detach to avoid pickling the computation graph
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        serialized_tensors = [
            serialize_torch_tensor(inp, proto.compression)
            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
        ]

        while True:
            chosen_expert = None  # stays None if the balancer fails before yielding an expert
            try:
                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
                    # Bound each attempt by forward_timeout so a hung expert triggers a retry instead of a hang.
                    deserialized_outputs = RemoteExpertWorker.run_coroutine(
                        asyncio.wait_for(
                            expert_forward(chosen_expert.uid, inputs, serialized_tensors, chosen_expert.stub),
                            timeout=forward_timeout,
                        )
                    )
                break
            except Exception as e:
                # Any failure (including a timeout) is logged, and the call is retried with another expert.
                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")

        return tuple(deserialized_outputs)

    @staticmethod
    @once_differentiable
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        """
        Backward pass through the remote expert; not implemented yet.

        :raises NotImplementedError: always
        """
        # TODO: implementing this requires:
        #  - saving the detached inputs in forward (ctx.save_for_backward(*inputs));
        #  - serializing them together with grad_outputs using
        #    nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])) for compression;
        #  - sending them to an expert chosen via ctx.expert_balancer.use_another_expert(ctx.backward_task_size)
        #    within ctx.backward_timeout, retrying on failure as forward does;
        #  - returning (DUMMY, None, None, None, None, None, None, *grad_inputs), i.e. one entry per forward
        #    argument after ctx.
        raise NotImplementedError("Backward is not yet implemented in this example")