diff --git a/src/server/handler.py b/src/server/handler.py
index a4aaf3f..798d4a5 100644
--- a/src/server/handler.py
+++ b/src/server/handler.py
@@ -7,6 +7,10 @@
 from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor
+from hivemind import nested_flatten, serialize_torch_tensor
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.utils import as_aiter
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
@@ -67,6 +71,140 @@ class TransformerConnectionHandler(ConnectionHandler):
         finally:
             print("CLOSED RPC_INFERENCE")
 
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        # Parse request and prepare backends
+        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_header(request)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a chain of requested backends
+        for backend in requested_backends:
+            assert isinstance(hidden_states, (list, tuple))
+            assert (
+                len(hidden_states) == 1 and hidden_states[0].ndim == 3
+            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+
+        # Serialize the overall output and respond
+        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        return runtime_pb2.ExpertResponse(
+            tensors=[
+                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            ]
+        )
+
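+    # rpc_forward_stream mirrors rpc_forward, except that both directions are streamed:
+    # the inputs are gathered from an iterator of requests, and the serialized output is
+    # chunked with split_for_streaming so that no message exceeds DEFAULT_MAX_MSG_SIZE.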
+    async def rpc_forward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        # Parse requests and prepare backends
+        uids_header, hidden_states = await self._gather_inputs(requests, context)
+        requested_uids = self._check_header_str(uids_header)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a chain of requested backends
+        for backend in requested_backends:
+            assert isinstance(hidden_states, (list, tuple))
+            assert (
+                len(hidden_states) == 1 and hidden_states[0].ndim == 3
+            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+
+        # Serialize the overall output
+        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        serialized_output = [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+        ]
+
+        # Split the serialized_output for streaming and respond
+        output_split = [
+            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+        ]
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        # Parse requests and prepare backends
+        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_header(request)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a forward chain to collect intermediate inputs
+        # Note that we do not forward for the last module since we do not need its output
+        inter_inputs = [inputs]
+        for backend in requested_backends[:-1]:
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            inputs = await backend.forward_pool.submit_task(inputs)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
+            inputs = inputs[0]
+            inter_inputs.append(inputs)
+
+        # Run a backward chain over the requested backends, in reverse order
+        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
+            inputs_and_grads = [inp, grads]
+            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
+            grads = grads[0]
+
+        # Serialize the overall grad_input and respond
+        return runtime_pb2.ExpertResponse(
+            tensors=[
+                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            ]
+        )
+
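+    # rpc_backward_stream is the streamed counterpart of rpc_backward: the same
+    # recompute-then-backprop chain, with inputs gathered from a request stream and
+    # the resulting grad_inputs split into DEFAULT_MAX_MSG_SIZE chunks for the reply.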
+    async def rpc_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        # Parse requests and prepare backends
+        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
+        inputs, grads = inputs_and_grads
+        requested_uids = self._check_header_str(uids_header)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a forward chain to collect intermediate inputs
+        # Note that we do not forward for the last module since we do not need its outputs
+        inter_inputs = [inputs]
+        for backend in requested_backends[:-1]:
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            inputs = await backend.forward_pool.submit_task(inputs)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
+            inputs = inputs[0]
+            inter_inputs.append(inputs)
+
+        # Run a backward chain over the requested backends, in reverse order
+        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
+            inputs_and_grads = [inp, grads]
+            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
+            grads = grads[0]
+
+        # Serialize the overall grad_inputs
+        serialized_grad_inputs = [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+        ]
+
+        # Split the serialized_grad_inputs for streaming and respond
+        output_split = [
+            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+        ]
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
     def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
         uids = (request.uid or "").split(CHAIN_DELIMITER)
@@ -77,6 +215,16 @@ class TransformerConnectionHandler(ConnectionHandler):
                 raise RuntimeError(f"Remote peer does not serve {uid}")
         return tuple(uids)
 
+    def _check_header_str(self, header) -> Sequence[ModuleUID]:
+        """Check that the uids header (a chain of module uids) is valid"""
+        uids = (header or "").split(CHAIN_DELIMITER)
+        if not uids:
+            raise RuntimeError("User did not provide any uids")
+        for uid in uids:
+            if uid not in self.module_backends:
+                raise RuntimeError(f"Remote peer does not serve {uid}")
+        return tuple(uids)
+
     @contextlib.asynccontextmanager
     async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
diff --git a/tests/test_chained_forward_backward.py b/tests/test_chained_forward_backward.py
new file mode 100644
index 0000000..c4835ce
--- /dev/null
+++ b/tests/test_chained_forward_backward.py
@@ -0,0 +1,59 @@
+######
+# Warning: this test is a work in progress. It will be modified soon.
+# - if you want more stable tests, see test_block_exact_match
+# - if you want to figure out chained inference, ask yozh
+
+import os
+
+import hivemind
+import torch
+from hivemind.moe.expert_uid import ExpertInfo
+
+from src.bloom.from_pretrained import load_pretrained_block
+from src.client.remote_block import RemoteTransformerBlock
+from src.dht_utils import get_remote_module
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+BLOCK_UID = os.environ.get("BLOCK_UID")
+if not BLOCK_UID:
+    raise RuntimeError("Must specify BLOCK_UID, the uid of a transformer block to be tested")
+
+REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
+
+
+# seq_length > 128: rpc_forward_stream & rpc_backward_stream
+# seq_length <= 128: rpc_forward & rpc_backward
+def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    (remote_block,) = get_remote_module(dht, BLOCK_UID)
+    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
+    assert isinstance(remote_block, RemoteTransformerBlock)
+
+    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
+    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)
+
+    ref_blocks = [
+        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
+    ]
+    inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
+    outputs_rpc = remote_block.forward(inputs)[0]
+    outputs_rpc.sum().backward()
+    grads_rpc = inputs.grad
+
+    inputs.grad = None
+    hidden_states = inputs
+    for ref_block in ref_blocks:
+        hidden_states = ref_block.forward(hidden_states)[0]
+    outputs_ref = hidden_states
+    outputs_ref.sum().backward()
+    grads_ref = inputs.grad
+
+    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
+    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
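
A note on consuming these streams: each streamed response carries one part of a single serialized tensor, so the client must reassemble the parts before deserializing. A minimal client-side sketch, assuming hivemind's combine_from_streaming as the inverse of split_for_streaming (both live in hivemind.utils.streaming); the collect_output_tensor helper and the response_stream argument are illustrative names, not part of this patch:

    import torch
    from hivemind import deserialize_torch_tensor
    from hivemind.utils.streaming import combine_from_streaming

    async def collect_output_tensor(response_stream) -> torch.Tensor:
        # Flatten the stream of ExpertResponse messages into a list of tensor parts,
        # stitch the parts back into one serialized tensor, then deserialize it.
        parts = [part async for message in response_stream for part in message.tensors]
        return deserialize_torch_tensor(combine_from_streaming(parts))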