add docstr

pull/467/head
Your Name 9 months ago
parent cc4fe17a99
commit 2e760319ab

@@ -93,7 +93,7 @@ class _ServerInferenceSession:
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
:prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
:param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
"""
if self.closed:
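
The docstring above pins down the expected shape of DEEP prompts. A minimal sketch of building such a tensor (the sizes and the commented session.step call are illustrative, not taken from this commit):

import torch

# DEEP prompts must be shaped [num_layers, batch_size, prefix_len, hid_size]
num_layers, batch_size, prefix_len, hid_size = 4, 2, 8, 4096
prompts = torch.zeros(num_layers, batch_size, prefix_len, hid_size, requires_grad=True)

# `session` is assumed to be an open inference session; each remote block adds
# its slice of `prompts` to the first prefix_len positions of its output:
# outputs = session.step(hidden_states, prompts=prompts)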

@@ -19,30 +19,30 @@ from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
timeout=config.request_timeout,
)
return [deserialize_torch_tensor(t) for t in outputs.tensors]
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
timeout=config.request_timeout,
)
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
@@ -52,10 +52,10 @@ async def _forward_stream(
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
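
Both streaming helpers above chunk every serialized tensor so that each gRPC message stays under the size limit before the server reassembles them. A standalone sketch of that round trip, assuming hivemind's streaming utilities and an illustrative chunk size (the real constant is DEFAULT_MAX_MSG_SIZE from the surrounding module):

import torch
from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

CHUNK_SIZE = 2 ** 21  # illustrative stand-in for DEFAULT_MAX_MSG_SIZE

tensor = torch.randn(1024, 1024)
serialized = serialize_torch_tensor(tensor)

parts = list(split_for_streaming(serialized, CHUNK_SIZE))  # each part fits into one message
restored = deserialize_torch_tensor(combine_from_streaming(parts))
assert torch.allclose(tensor, restored)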
@@ -68,31 +68,19 @@ async def run_remote_forward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
*inputs: torch.Tensor,
*forward_inputs: torch.Tensor,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
) -> Tuple[torch.Tensor, ...]:
"""
Serializes input tensors and calls "rpc_forward" on a remote server.
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = tuple(nested_flatten((inputs, kwargs)))
args_schema, kwargs_schema = rpc_info["forward_schema"]
compression = args_schema[0].compression
forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# TODO: create more explicit way to check servers schema and client's structure
assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
@@ -106,7 +94,7 @@ async def run_remote_forward(
size = sum(t.element_size() * t.nelement() for t in inputs)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata)
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
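
The tail of this hunk keeps the size check that picks between the unary and streaming RPC. A self-contained sketch of that decision (the constant value below is assumed for illustration; the module uses MAX_UNARY_PAYLOAD_SIZE):

import torch

MAX_UNARY_PAYLOAD_SIZE = 4 * 1024 * 1024  # assumed value, not the module's real constant

def use_streaming(tensors) -> bool:
    size = sum(t.element_size() * t.nelement() for t in tensors)
    # "// 2" mirrors the hotfix above: hivemind==1.1.5 serializes bfloat16 as
    # float32, so the wire payload can be twice the in-memory size
    return size > MAX_UNARY_PAYLOAD_SIZE // 2

print(use_streaming([torch.zeros(1, 2048, 4096, dtype=torch.bfloat16)]))  # True: ~16 MiB of activations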

@@ -4,7 +4,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
import asyncio
import itertools
from collections import deque
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple, Dict, Any
import torch
from hivemind import MSGPackSerializer
@@ -29,14 +29,25 @@ async def sequential_forward(
sequence_manager: RemoteSequenceManager,
start_index: int = 0,
end_index: Optional[int] = None,
block_kwargs: Sequence[Dict[str, Any]] = (),
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
"""
Constructs a routing path from <start_index> to <end_index>.
Performs chained forward for each subsequence of blocks on the path.
If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
:param inputs: initial hidden states of shape [batch_size, sequence length, hidden_size]
:param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
:param sequence_manager: a running SequenceManager used to select remote servers and handle failures
:param start_index: run remote blocks starting from this index
:param end_index: run remote blocks up to (but not including) this index
:param block_kwargs: optional per-block keyword arguments. Must be a sequence with one dictionary for each block
"""
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
assert len(block_kwargs) in (0, 1, end_index - start_index), \
f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
inputs_device = inputs.device
inputs_dtype = inputs.dtype
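
The new block_kwargs argument is validated right after the docstring: it may be empty, hold a single dictionary, or hold one dictionary per requested block. An illustrative call for a run over 4 remote blocks (the kwarg name is hypothetical, and sequential_forward must be awaited from async code):

import torch

hidden = torch.randn(1, 16, 4096)     # [batch_size, sequence length, hidden_size]
prompts = torch.zeros(4, 1, 0, 4096)  # DEEP prompts with an empty prefix for 4 blocks

no_extra_kwargs = ()                                               # len 0: nothing extra is sent
one_dict_per_block = tuple({"some_flag": True} for _ in range(4))  # len 4: one dict per block

# outputs, intermediates, spans = await sequential_forward(
#     hidden, prompts, sequence_manager, end_index=4, block_kwargs=one_dict_per_block)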
@@ -68,7 +79,8 @@ async def sequential_forward(
span = sequences.popleft()
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
flat_tensors, args_structure = pack_args_kwargs(
inputs, prompts[span.start : span.end], *block_kwargs[span.start: span.end])
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
metadata = sequence_manager.get_request_metadata(
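
The replacement call packs the hidden states, the matching prompt slice, and the per-block kwargs into a flat list of tensors plus a structure descriptor before they are serialized for the server. A sketch of that helper, assuming it is the pack_args_kwargs / unpack_args_kwargs pair from petals.utils.packaging (only the pack call is confirmed by the hunk above):

import torch
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

inputs = torch.randn(1, 16, 4096)
prompt_slice = torch.zeros(2, 1, 0, 4096)
block_kwargs = {"extra_input": torch.ones(1)}  # hypothetical per-block kwargs

flat_tensors, args_structure = pack_args_kwargs(inputs, prompt_slice, block_kwargs)
# flat_tensors: a flat list of tensors ready for serialization
# args_structure: nesting info used on the other side to rebuild the arguments
restored = unpack_args_kwargs(flat_tensors, args_structure)  # assumed to restore the original nesting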
