wip some more

This commit is contained in:
Your Name 2023-09-05 22:07:48 +03:00
parent e5c2d8eca4
commit 49474e5477
4 changed files with 83 additions and 73 deletions

View File

@ -4,7 +4,7 @@ import asyncio
import itertools
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple
from typing import AsyncIterator, List, Optional, Tuple, Sequence
import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
@ -34,7 +34,7 @@ class _ServerInferenceSession:
self,
config: ClientConfig,
span: RemoteSpanInfo,
uid: ModuleUID,
span_uids: Sequence[ModuleUID],
rpc_info: RPCInfo,
inputs_queue: asyncio.Queue,
outputs_aiter: AsyncIterator,
@ -43,8 +43,8 @@ class _ServerInferenceSession:
**metadata,
):
self.config = config
self.span, self.uid, self.rpc_info = span, uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
self.num_blocks = len(span_uids)
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.session_id = str(uuid.uuid4())
@ -62,18 +62,19 @@ class _ServerInferenceSession:
config: ClientConfig,
p2p: P2P,
span: RemoteSpanInfo,
uid: ModuleUID,
span_uids: Sequence[RemoteSpanInfo],
rpc_info: RPCInfo,
**metadata,
) -> _ServerInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
# TODO YOZH you don't need rpc info here
stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
config.connect_timeout,
)
return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@ -142,6 +143,7 @@ class _ServerInferenceSession:
request_metadata["args_structure"] = args_structure
# TODO YOZH FIX THIS BEFORE THE END OF THIS PR
# TODO: make possible to use different compression method for different tensors
server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
compression = server_side_inference_schema[0].compression
@ -155,7 +157,7 @@ class _ServerInferenceSession:
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
uid=CHAIN_DELIMITER.join(self.span_uids),
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(input_tensors, inference_schema)
@ -244,8 +246,8 @@ class InferenceSession:
server_sessions = []
try:
for span in chosen_spans:
span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
span_uids = self._sequence_manager.block_uids[span.start : span.end]
metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids)
session = RemoteExpertWorker.run_coroutine(
_ServerInferenceSession.create(
self._sequence_manager.config,

View File

@ -2,50 +2,49 @@
Utility functions that call RPC forward or backward on a single remote server
"""
import asyncio
from typing import Iterable, List, Optional, Sequence, Tuple
from typing import Iterable, List, Sequence, Tuple
import torch
from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor, PeerID
from hivemind import PeerID, nested_flatten, serialize_torch_tensor
from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
from hivemind.p2p import StubBase
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.serializer import MSGPackSerializer
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals import RemoteSequenceManager
from petals.client.config import ClientConfig
from petals.data_structures import ModuleUID, RPCInfo, CHAIN_DELIMITER
from petals.client.routing import RemoteSequenceManager
from petals.data_structures import CHAIN_DELIMITER, ModuleUID
from petals.server.handler import TransformerConnectionHandler
from petals.utils.packaging import pack_args_kwargs
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
timeout=config.request_timeout,
)
return [deserialize_torch_tensor(t) for t in outputs.tensors]
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
timeout=config.request_timeout,
)
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
@ -55,10 +54,10 @@ async def _forward_stream(
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
@ -81,35 +80,39 @@ async def run_remote_forward(
"""
merged_uid = CHAIN_DELIMITER.join(span_uids)
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
flat_inputs, args_structure = pack_args_kwargs(*args, **kwargs)
metadata = sequence_manager.get_request_metadata(peer_id, "rpc_forward", span_uids, *args, **kwargs)
compressions = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs)
if compressions is None:
compressions = [runtime_pb2.CompressionType.NONE] * len(flat_inputs)
compressions = list(nested_flatten(compressions))
assert len(compressions) == len(flat_inputs), f"got {len(flat_inputs)} tensors but {len(compressions)} codecs"
inputs = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_inputs)
codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs)
flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
args_structure = metadata.setdefault("args_structure", args_structure)
if codecs is None:
codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
else:
codecs = list(nested_flatten(codecs))
assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
for tensor, compression in zip(inputs, compressions)
for tensor, compression in zip(flat_tensors, codecs)
)
)
# call RPC on remote server
size = sum(t.element_size() * t.nelement() for t in inputs)
size = sum(t.element_size() * t.nelement() for t in flat_tensors)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR
return await forward_fn(merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata)
return await forward_fn(
merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
)
async def run_remote_backward(
sequence_manager: RemoteSequenceManager,
peer_id: PeerID,
span_uids: Sequence[ModuleUID],
stub: StubBase,
grad_outputs: Sequence[torch.Tensor],
*args: torch.Tensor,
**kwargs: torch.Tensor,
@ -119,23 +122,32 @@ async def run_remote_backward(
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
flat_tensors, args_structure = pack_args_kwargs(
[grad.cpu() for grad in grad_outputs], args, kwargs
)
metadata = sequence_manager.get_request_metadata(
"rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
)
merged_uid = CHAIN_DELIMITER.join(span_uids)
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
metadata = sequence_manager.get_request_metadata(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs)
codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs)
flat_tensors, args_structure = pack_args_kwargs(grad_outputs, *args, **kwargs)
flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
args_structure = metadata.setdefault("args_structure", args_structure)
if codecs is None:
codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
else:
codecs = list(nested_flatten(codecs))
assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, compression)
for tensor, proto in zip(flat_inputs_and_grad_outputs, backward_schema)
loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
for tensor, compression in zip(flat_tensors, codecs)
)
)
size = sum(t.element_size() * t.nelement() for t in flat_inputs_and_grad_outputs)
size = sum(t.element_size() * t.nelement() for t in flat_tensors)
backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
return await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata)
return await backward_fn(
merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
)

View File

@ -474,7 +474,9 @@ class RemoteSequenceManager:
return 0
return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
def get_request_metadata(self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Dict[str, Any]]:
def get_request_metadata(
self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
) -> Optional[Dict[str, Any]]:
"""
:param peer_id: remote server's PeerID
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
@ -488,7 +490,8 @@ class RemoteSequenceManager:
)
def get_compression_codecs(
self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
"""
:param peer_id: remote server's PeerID
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"

View File

@ -4,19 +4,16 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
import asyncio
import itertools
from collections import deque
from typing import List, Optional, Sequence, Tuple, Dict, Any
from typing import Any, Dict, List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.data_structures import RemoteSpanInfo
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -24,12 +21,12 @@ MAX_TOKENS_IN_BATCH = 1024
async def sequential_forward(
sequence_manager: RemoteSequenceManager,
inputs: torch.Tensor,
prompts: torch.Tensor,
sequence_manager: RemoteSequenceManager,
start_index: int = 0,
end_index: Optional[int] = None,
block_kwargs: Sequence[Dict[str, Any]] = (),
*block_kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
"""
Constructs a routing path from <start_index> to <end_index>.
@ -45,13 +42,6 @@ async def sequential_forward(
:param block_kwargs: optional per-block keyword arguments. Must be a sequence with one dictionary for each block
"""
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
assert len(block_kwargs) in (
0,
1,
end_index - start_index,
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
inputs_device = inputs.device
inputs_dtype = inputs.dtype
inputs = inputs.cpu()
@ -59,6 +49,9 @@ async def sequential_forward(
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
assert len(block_kwargs) in (0, 1, end_index - start_index), \
f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
assert is_dummy(prompts) or len(prompts) == len(
sequence_manager.block_uids
) # should be n_layers - 1 but add extra prompts for convenience
@ -87,13 +80,13 @@ async def sequential_forward(
sequence_manager.block_uids[span.start : span.end],
inputs,
prompts[span.start : span.end],
*block_kwargs[span.start : span.end]
*block_kwargs[span.start : span.end],
)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
# Save intermediate inputs and subsequ_peerences if the forward is already done for them
# Save intermediate inputs and subsequences if the forward is already done for them
intermediate_inputs.append(inputs)
done_sequences.append(span)
@ -118,11 +111,12 @@ async def sequential_forward(
async def sequential_backward(
sequence_manager: RemoteSequenceManager,
forward_sequences: List[RemoteSpanInfo],
grad_outputs: Sequence[torch.Tensor],
intermediate_inputs: List[torch.Tensor],
prompts: torch.Tensor,
forward_sequences: List[RemoteSpanInfo],
sequence_manager: RemoteSequenceManager,
*block_kwargs: Dict[str, Any],
) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
"""
Performs chained backward for each forward subsequence.
@ -148,7 +142,7 @@ async def sequential_backward(
try:
if attempt_no >= 1:
_, backup_inputs, backup_sequences = await sequential_forward(
inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
sequence_manager, inputs, prompts, start_index=span.start, end_index=span.end
)
assert len(backup_inputs) == len(backup_sequences)
assert backup_sequences[0].start == span.start
@ -159,14 +153,13 @@ async def sequential_backward(
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
grad_outputs, *span_grad_prompts = await run_remote_backward(
sequence_manager,
sequence_manager.block_uids[span.start: span.end],
span_uids,
grad_outputs, inputs,
span.peer_id,
sequence_manager.block_uids[span.start : span.end],
grad_outputs,
*inputs,
*block_kwargs[span.start : span.end],
)
grad_outputs = [grad_outputs]
grad_prompts_reversed.extend(span_grad_prompts)
@ -198,7 +191,7 @@ async def _gather_forward(input_batches, prompt_batches, sequence_manager):
"""Wrapper for asyncio.gather to perform parallel sequential forwards"""
return await asyncio.gather(
*[
sequential_forward(input_batch, prompt_batch, sequence_manager)
sequential_forward(sequence_manager, input_batch, prompt_batch)
for input_batch, prompt_batch in zip(input_batches, prompt_batches)
]
)
@ -210,7 +203,7 @@ async def _gather_backward(
"""Wrapper for asyncio.gather to perform parallel sequential backwards"""
return await asyncio.gather(
*[
sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
sequential_backward(sequence_manager, spans, (grad_output,), input_batch, prompt_batch)
for grad_output, input_batch, prompt_batch, spans in zip(
grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
)