|
|
|
@ -19,30 +19,30 @@ from petals.data_structures import ModuleUID, RPCInfo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _forward_unary(
|
|
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
|
|
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
|
|
|
|
|
) -> List[torch.Tensor]:
|
|
|
|
|
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
|
|
|
|
|
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
|
|
|
|
|
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
|
|
|
|
|
timeout=config.request_timeout,
|
|
|
|
|
)
|
|
|
|
|
return [deserialize_torch_tensor(t) for t in outputs.tensors]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _backward_unary(
|
|
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
|
|
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
|
|
|
|
|
) -> List[torch.Tensor]:
|
|
|
|
|
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
|
|
|
|
|
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
|
|
|
|
|
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
|
|
|
|
|
timeout=config.request_timeout,
|
|
|
|
|
)
|
|
|
|
|
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _forward_stream(
|
|
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
|
|
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
|
|
|
|
|
) -> List[torch.Tensor]:
|
|
|
|
|
parts = (
|
|
|
|
|
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
|
|
|
|
|
runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
|
|
|
|
|
for tensor in serialized_tensors
|
|
|
|
|
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
|
|
|
|
|
)
|
|
|
|
@ -52,10 +52,10 @@ async def _forward_stream(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _backward_stream(
|
|
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
|
|
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
|
|
|
|
|
) -> List[torch.Tensor]:
|
|
|
|
|
parts = (
|
|
|
|
|
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
|
|
|
|
|
runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
|
|
|
|
|
for tensor in serialized_tensors
|
|
|
|
|
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
|
|
|
|
|
)
|
|
|
|
@ -68,31 +68,19 @@ async def run_remote_forward(
|
|
|
|
|
uid: ModuleUID,
|
|
|
|
|
stub: StubBase,
|
|
|
|
|
rpc_info: RPCInfo,
|
|
|
|
|
*inputs: torch.Tensor,
|
|
|
|
|
*forward_inputs: torch.Tensor,
|
|
|
|
|
config: ClientConfig,
|
|
|
|
|
metadata: Optional[bytes] = None,
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> Tuple[torch.Tensor, ...]:
|
|
|
|
|
"""
|
|
|
|
|
Serializes input tensors and calls "rpc_forward" on a remote server.
|
|
|
|
|
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
|
|
|
|
|
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
|
|
|
|
|
# detach to avoid pickling the computation graph
|
|
|
|
|
assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
|
|
|
|
|
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
|
|
|
|
|
|
|
|
|
|
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
|
|
|
|
|
forward_inputs = tuple(nested_flatten((inputs, kwargs)))
|
|
|
|
|
args_schema, kwargs_schema = rpc_info["forward_schema"]
|
|
|
|
|
compression = args_schema[0].compression
|
|
|
|
|
forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
|
|
|
|
|
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
|
|
|
|
|
# TODO: create more explicit way to check servers schema and client's structure
|
|
|
|
|
assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
|
|
|
|
|
|
|
|
|
|
# Asynchronous serialization
|
|
|
|
|
loop = asyncio.get_running_loop()
|
|
|
|
|
serialized_tensors = await asyncio.gather(
|
|
|
|
@ -106,7 +94,7 @@ async def run_remote_forward(
|
|
|
|
|
size = sum(t.element_size() * t.nelement() for t in inputs)
|
|
|
|
|
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
|
|
|
|
|
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
|
|
|
|
|
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
|
|
|
|
|
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata)
|
|
|
|
|
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|