You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
petals/client.py

153 lines
7.4 KiB
Python

from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch.autograd.function import once_differentiable
import hivemind
from load_balancer import LoadBalancer
from hivemind.moe.client.expert import DUMMY, expert_forward
from hivemind.proto import runtime_pb2
from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
logger = get_logger(__name__)
MAX_NODES = 99999
class BalancedRemoteExpert(nn.Module):
"""
A torch module that dynamically assigns weights to one RemoteExpert from a pool, proportionally to their throughput.
ToDo docstring, similar to hivemind.RemoteExpert
"""
def __init__(
self,
*,
dht: hivemind.DHT,
uid_prefix: str,
grid_size: Tuple[int, ...] = (1, MAX_NODES),
forward_timeout: Optional[float] = None,
backward_timeout: Optional[float] = None,
update_period: float = 30.0,
backward_task_size_multiplier: float = 2.5,
**kwargs,
):
super().__init__()
if uid_prefix.endswith(".0."):
logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}0.")
assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
self.backward_task_size_multiplier = backward_task_size_multiplier
self.expert_balancer = LoadBalancer(dht, key=f"{self.uid_prefix}0.", update_period=update_period, **kwargs)
self._expert_info = None # expert['info'] from one of experts in the grid
def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
"""
Call one of the RemoteExperts for the specified inputs and return output. Compatible with pytorch.autograd.
:param args: input tensors that will be passed to each expert after input, batch-first
:param kwargs: extra keyword tensors that will be passed to each expert, batch-first
:returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
"""
assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
if self._expert_info is None:
raise NotImplementedError()
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (args, kwargs)
if not nested_compare(forward_inputs, self.info["forward_schema"]):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
flat_inputs = list(nested_flatten(forward_inputs))
forward_task_size = flat_inputs[0].shape[0]
# Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
self.expert_balancer,
self.info,
self.forward_timeout,
self.backward_timeout,
forward_task_size,
forward_task_size * self.backward_task_size_multiplier,
*flat_inputs)
return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
@property
def info(self):
while self._expert_info is None:
try:
with self.expert_balancer.use_another_expert(1) as chosen_expert:
self._expert_info = chosen_expert.info
except BaseException as e:
logger.error(f"Tried to get expert info from {chosen_expert} but caught {repr(e)}")
return self._expert_info
class _BalancedRemoteModuleCall(torch.autograd.Function):
"""Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
@staticmethod
def forward(
ctx,
dummy: torch.Tensor,
expert_balancer: LoadBalancer,
info: Dict[str, Any],
forward_timeout: float,
backward_timeout: float,
forward_task_size: float,
backward_task_size: float,
*inputs: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
ctx.expert_balancer, ctx.info = expert_balancer, info
ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
inputs = tuple(tensor.cpu().detach() for tensor in inputs)
ctx.save_for_backward(*inputs)
serialized_tensors = [
serialize_torch_tensor(inp, proto.compression)
for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
]
while True:
try:
with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
deserialized_outputs = RemoteExpertWorker.run_coroutine(expert_forward(
chosen_expert.uid, inputs, serialized_tensors, chosen_expert.stub))
break
except BaseException as e:
logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
return tuple(deserialized_outputs)
@staticmethod
@once_differentiable
def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
raise NotImplementedError("Backward is not yet implemented in this example")
# grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
# inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
# backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
# serialized_tensors = [
# serialize_torch_tensor(tensor, proto.compression)
# for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
# ]
# while True:
# try:
# with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
# backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
# grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
# break
# except BaseException as e:
# logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
# deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
# return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)