add chained rpc_forward & rpc_backward

rpc
Dmitry Baranchuk 2 years ago
parent 0b5a68983f
commit 4cb986f680

@@ -7,6 +7,9 @@ from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import anext
from hivemind.utils.streaming import split_for_streaming
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.utils import as_aiter
from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend
@@ -67,6 +70,140 @@ class TransformerConnectionHandler(ConnectionHandler):
        finally:
            print("CLOSED RPC_INFERENCE")

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse request and prepare backends
        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        # The request UID is a CHAIN_DELIMITER-separated chain of block UIDs, all of which must be served by this peer
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a chain of requested backends
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        # Serialize the overall output and respond
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )
    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Parse requests and prepare backends
        uids_header, hidden_states = await self._gather_inputs(requests, context)
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a chain of requested backends
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        # Serialize the overall output
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        serialized_output = [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized output into message-sized chunks and stream them back
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])
    async def rpc_backward(
        self, request: runtime_pb2.ExpertRequest, context: P2PContext
    ) -> runtime_pb2.ExpertResponse:
        # Parse requests and prepare backends
        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a forward chain to collect intermediate inputs
        # Note that we do not forward for the last module since we do not need its output
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Run a backward chain through the requested backends in reverse order
        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
            inputs_and_grads = [inp, grads]
            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]

        # At this point, grads hold d(loss)/d(inputs) for the first block in the chain
        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
            ]
        )
    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
        inputs, grads = inputs_and_grads
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a forward chain to collect intermediate inputs
        # Note that we do not forward for the last module since we do not need its outputs
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Run a backward chain for requested backends
        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
            inputs_and_grads = [inp, grads]
            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
        ]

        # Split the serialized grad_inputs into message-sized chunks and stream them back
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
        """Check that the first request to rpc_inference is valid"""
        uids = (request.uid or "").split(CHAIN_DELIMITER)
@@ -77,6 +214,16 @@ class TransformerConnectionHandler(ConnectionHandler):
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

    def _check_header_str(self, header) -> Sequence[ModuleUID]:
        """Check that a chain of block UIDs (as sent to rpc_forward_stream / rpc_backward_stream) is valid"""
        uids = (header or "").split(CHAIN_DELIMITER)
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
        """Allocate memory caches for each transformer block, return cache handles"""

@@ -0,0 +1,59 @@
######
# Warning: this test is a work in progress. It will be modified soon.
# - if you want more stable tests, see test_block_exact_match
# - if you want to figure out chained inference, ask yozh
import os
import hivemind
import torch
from hivemind.moe.expert_uid import ExpertInfo
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
BLOCK_UID = os.environ.get("BLOCK_UID")
if not BLOCK_UID:
raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
# seq_length > 128: rpc_forward_stream & rpc_backward_stream
# seq_length <= 128: rpc_forward & rpc_backward
# (a streaming variant of this check is sketched after the test below)
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    remote_block, = get_remote_module(dht, BLOCK_UID)
    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
    assert isinstance(remote_block, RemoteTransformerBlock)

    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)

    ref_blocks = [
        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
    ]

    inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
    outputs_rpc = remote_block.forward(inputs)[0]
    outputs_rpc.sum().backward()
    grads_rpc = inputs.grad

    inputs.grad = None
    hidden_states = inputs
    for ref_block in ref_blocks:
        hidden_states = ref_block.forward(hidden_states)[0]
    outputs_ref = hidden_states
    outputs_ref.sum().backward()
    grads_ref = inputs.grad

    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
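Per the seq_length comment above, the streaming endpoints can be exercised with the same exact-match check by using a longer sequence. A minimal sketch of such a variant (not part of this commit; the 128-token threshold is taken from that comment):

# Re-run the exact-match check with a sequence long enough to route through
# rpc_forward_stream / rpc_backward_stream instead of rpc_forward / rpc_backward
def test_forward_backward_exact_match_streaming():
    test_forward_backward_exact_match(seq_length=256)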