diff --git a/examples/workbench_call_rpc_directly.ipynb b/examples/workbench_call_rpc_directly.ipynb
new file mode 100644
index 0000000..2793fda
--- /dev/null
+++ b/examples/workbench_call_rpc_directly.ipynb
@@ -0,0 +1,254 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "21e78d30",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import asyncio\n",
+    "from typing import Sequence, Tuple, Iterable, List\n",
+    "from tqdm.auto import trange\n",
+    "\n",
+    "import torch\n",
+    "import hivemind\n",
+    "import petals\n",
+    "\n",
+    "from petals.server.handler import TransformerConnectionHandler, split_for_streaming\n",
+    "from petals.client import RemoteSequenceManager, ClientConfig\n",
+    "from petals.client.remote_forward_backward import DEFAULT_MAX_MSG_SIZE, iter_as_aiter, aiter_with_timeout, deserialize_tensor_stream\n",
+    "from petals.data_structures import ModuleUID, PeerID, CHAIN_DELIMITER, UID_DELIMITER\n",
+    "from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs\n",
+    "\n",
+    "from hivemind.compression import serialize_torch_tensor\n",
+    "from hivemind.utils import MSGPackSerializer, nested_flatten\n",
+    "from hivemind.proto import runtime_pb2\n",
+    "\n",
+    "_END_OF_STREAM_KEY = \"_EOS\"\n",
+    "\n",
+    "\n",
+    "async def pack_as_expert_requests(uid, flat_tensors, codecs, metadata):\n",
+    "    # Asynchronous serialization\n",
+    "    loop = asyncio.get_running_loop()\n",
+    "    serialized_tensors = await asyncio.gather(\n",
+    "        *(\n",
+    "            loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)\n",
+    "            for tensor, compression in zip(flat_tensors, codecs)\n",
+    "        )\n",
+    "    )\n",
+    "\n",
+    "    parts = [\n",
+    "        tensor_part for tensor in serialized_tensors\n",
+    "        for tensor_part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)\n",
+    "    ]\n",
+    "    # Mark only the last piece with the end-of-stream flag so the server knows when this group of requests ends\n",
+    "    serialized_metadata = MSGPackSerializer.dumps(metadata)\n",
+    "    serialized_metadata_last_piece = MSGPackSerializer.dumps(dict(metadata, **{_END_OF_STREAM_KEY: True}))\n",
+    "\n",
+    "    return [\n",
+    "        runtime_pb2.ExpertRequest(\n",
+    "            uid=uid, tensors=[tensor_part],\n",
+    "            metadata=serialized_metadata if i != len(parts) - 1 else serialized_metadata_last_piece)\n",
+    "        for i, tensor_part in enumerate(parts)\n",
+    "    ]\n",
+    "\n",
+    "\n",
+    "async def run_remote_forward_backward(\n",
+    "    sequence_manager: RemoteSequenceManager,\n",
+    "    peer_id: PeerID,\n",
+    "    span_uids: Sequence[ModuleUID],\n",
+    "    *args: torch.Tensor,\n",
+    "    **kwargs: torch.Tensor,\n",
+    ") -> Tuple[torch.Tensor, ...]:\n",
+    "    \"\"\"\n",
+    "    Serializes input tensors and calls \"rpc_forward_backward_stream\" on a remote server.\n",
+    "    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198\n",
+    "    but without the RemoteExpertWorker.run_coroutine() call, which leads to a deadlock here.\n",
+    "    \"\"\"\n",
+    "    merged_uid = CHAIN_DELIMITER.join(span_uids)\n",
+    "    stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)\n",
+    "    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)\n",
+    "    metadata = sequence_manager.get_request_metadata(\"rpc_forward\", args_structure, uids=span_uids, *args, peer_id=peer_id, **kwargs)  # TODO: fix the metadata API\n",
+    "    # codecs = sequence_manager.get_compression_codecs(peer_id, \"rpc_forward\", span_uids, *args, **kwargs)\n",
+    "    codecs = [runtime_pb2.CompressionType.NONE for _ in args]  # TODO: replace with proper compression\n",
+    "    flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)\n",
+    "    args_structure = metadata.setdefault(\"args_structure\", args_structure)\n",
+    "    if codecs is None:\n",
+    "        codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)\n",
+    "    else:\n",
+    "        codecs = list(nested_flatten(codecs))\n",
+    "    assert len(codecs) == len(flat_tensors), f\"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs\"\n",
+    "\n",
+    "    # call RPC on remote server\n",
+    "    size = sum(t.element_size() * t.nelement() for t in flat_tensors)\n",
+    "    # Hotfix: we use \"// 2\" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR\n",
+    "    # NOTE: size is currently unused; the inlined sender below always streams\n",
+    "\n",
+    "    ### HERE BEGINS THE INLINED REQUEST SENDER\n",
+    "    # used to look like this:\n",
+    "    # output_tensors = await _run_forward_part(\n",
+    "    #     merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata\n",
+    "    # )\n",
+    "    config = sequence_manager.config\n",
+    "    assert _END_OF_STREAM_KEY not in metadata\n",
+    "    forward_requests = await pack_as_expert_requests(merged_uid, flat_tensors, codecs, metadata)\n",
+    "    backward_codecs = [runtime_pb2.CompressionType.NONE]  # TODO: replace with proper compression\n",
+    "    fake_grad_outputs = torch.randn_like(flat_tensors[0])\n",
+    "    _, backward_args_structure = pack_args_kwargs(args[0], fake_grad_outputs, *args[1:], **kwargs)\n",
+    "    backward_metadata = dict(metadata, args_structure=backward_args_structure)\n",
+    "\n",
+    "    grad_requests = await pack_as_expert_requests(merged_uid, (fake_grad_outputs,), backward_codecs, backward_metadata)\n",
+    "\n",
+    "    received_outputs = asyncio.Event()\n",
+    "\n",
+    "    async def iterate_inputs():\n",
+    "        for request in forward_requests:\n",
+    "            yield request\n",
+    "        print(\"WAITING FOR OUTPUTS\")\n",
+    "        await received_outputs.wait()\n",
+    "        print(\"RECEIVED OUTPUTS - SENDING GRADS\")\n",
+    "        for request in grad_requests:\n",
+    "            yield request\n",
+    "        print(\"SENT GRADS\")\n",
+    "\n",
+    "    async def _wrap_input_stream(stream):\n",
+    "        async for expert_request in stream:\n",
+    "            yield expert_request\n",
+    "            if not expert_request.metadata:\n",
+    "                continue  # TODO: handle this more generally\n",
+    "            metadata = MSGPackSerializer.loads(expert_request.metadata)\n",
+    "            print(metadata)\n",
+    "            if metadata.get(_END_OF_STREAM_KEY):\n",
+    "                break\n",
+    "\n",
+    "    print(\"CALLING stub.rpc_forward_stream on serialized inputs\", iterate_inputs())\n",
+    "    outputs_stream = await asyncio.wait_for(stub.rpc_forward_backward_stream(iterate_inputs()), config.connect_timeout)\n",
+    "    outputs_stream = aiter_with_timeout(outputs_stream, config.request_timeout)\n",
+    "\n",
+    "    output_hidden_states = await deserialize_tensor_stream(msg.tensors async for msg in _wrap_input_stream(outputs_stream))\n",
+    "    received_outputs.set()\n",
+    "\n",
+    "    grad_inputs = await deserialize_tensor_stream(msg.tensors async for msg in _wrap_input_stream(outputs_stream))\n",
+    "    print(\"RECEIVED GRAD INPUTS\")\n",
+    "\n",
+    "    ### HERE ENDS THE INLINED REQUEST SENDER\n",
+    "\n",
+    "    # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591\n",
+    "    requires_grad = any(tensor.requires_grad for tensor in flat_tensors)\n",
+    "    output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_hidden_states]\n",
+    "    return output_tensors, grad_inputs\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "1c47c89a",
+   "metadata": {},
+ "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Mar 17 18:37:25.661 [\u001b[1m\u001b[34mINFO\u001b[0m] Make sure you follow the LLaMA's terms of use: https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1\n", + "Mar 17 18:37:25.661 [\u001b[1m\u001b[34mINFO\u001b[0m] Using DHT prefix: TinyLLama-v0-hf\n", + "100%|██████████| 1/1 [00:00<00:00, 26.19it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CALLING stub.rpc_forward_stream on serialized inputs .iterate_inputs at 0x75eb8d134d60>\n", + "WAITING FOR OUTPUTS\n", + "{'_EOS': True}\n", + "RECEIVED OUTPUTS - SENDING GRADS\n", + "SENT GRADS\n", + "RECEIVED GRAD INPUTS\n", + "outputs: tensor([[[-0.0835, 0.3027, 0.2217, ..., 1.1719 ...\n", + "It works!\n", + "shutting down\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "INITIAL_PEERS = ['/ip4/127.0.0.1/tcp/1337/p2p/QmRTdR9XmTHNXKiwtqRJ4i7tNofnmFrxkufBefguZUyXej']\n", + "peer_id_string = INITIAL_PEERS[0].split(\"/\")[-1]\n", + "model_name = \"Maykeye/TinyLLama-v0\"\n", + "\n", + "model_config = petals.DistributedLlamaConfig.from_pretrained(model_name)\n", + "block_uids = [\n", + " f\"{model_config.dht_prefix}{UID_DELIMITER}{i}\"\n", + " for i in range(model_config.num_hidden_layers)\n", + "]\n", + "\n", + "block_in_use = block_uids[0:2]\n", + "\n", + "try:\n", + " dht = hivemind.DHT(start=True, client_mode=True, initial_peers=INITIAL_PEERS)\n", + " sequence_manager = petals.RemoteSequenceManager(model_config, block_uids, dht=dht)\n", + " sequence_manager.rpc_info\n", + " p2p = await dht.replicate_p2p()\n", + " \n", + " dummy_inputs = [\n", + " torch.rand(1, 128, model_config.hidden_size, dtype=model_config.torch_dtype),\n", + " torch.empty(0, dtype=model_config.torch_dtype),\n", + " ]\n", + " peer_id = hivemind.PeerID.from_base58(peer_id_string)\n", + " for i in trange(1):\n", + " (outputs,), grads = await run_remote_forward_backward(sequence_manager, peer_id, block_in_use, *dummy_inputs)\n", + " print('outputs:', repr(outputs)[:50], '...')\n", + " print(\"It works!\")\n", + "\n", + "finally:\n", + " print(\"shutting down\")\n", + " await p2p.shutdown()\n", + " dht.shutdown() # it is okay to remove this clause, but you will be summoning a horde of daemons as you debug" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f72fac2c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5392ba6a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/server1.id b/server1.id new file mode 100644 index 0000000..9f982ec Binary files /dev/null and b/server1.id differ diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 2465656..ecd0b74 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -590,3 +590,79 @@ class TransformerConnectionHandler(ConnectionHandler): result.update(block_info) return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result)) + + 
+    @staticmethod
+    async def _read_until_eos(stream):
+        while True:
+            expert_request = await anext(stream)
+            yield expert_request
+            metadata = MSGPackSerializer.loads(expert_request.metadata)
+            print(metadata)
+            if metadata.get("_EOS"):
+                break
+
+    async def rpc_forward_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        async with timeout(self.request_timeout):
+            # Parse requests and prepare backends
+            uid_str, flat_inputs, metadata = await self._gather_inputs(self._read_until_eos(requests), context)
+            requested_uids = self._check_uids(uid_str)
+            self._log_request("rpc_forward_backward_stream", requested_uids, context)
+
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = self._get_active_adapter(metadata)
+            points = metadata.get("points", 0)
+            args_structure = metadata.get("args_structure")
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_forward_backward_stream should have number of points as number or None, got {points}"
+
+            print(f"{requested_backends=}, {active_adapter=}, {points=}, {args_structure=}")
+
+            hidden_states = await run_rpc_forward(
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
+                args_structure=args_structure,
+            )
+
+            # Stream the hidden states back, marking only the last piece with the end-of-stream flag
+            # so the client knows when to switch to sending grad_outputs
+            output_parts = [
+                part
+                for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata)
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            ]
+            eos_metadata = MSGPackSerializer.dumps({"_EOS": True})
+            for i, part in enumerate(output_parts):
+                yield runtime_pb2.ExpertResponse(
+                    tensors=[part], metadata=eos_metadata if i == len(output_parts) - 1 else b""
+                )
+
+            # Second half of the stream: receive grad_outputs and run the backward pass
+            new_uid_str, flat_extra_inputs, extra_metadata = await self._gather_inputs(self._read_until_eos(requests), context)
+            backward_args_structure = extra_metadata.get("args_structure")
+            assert len(flat_extra_inputs) == 1
+            assert new_uid_str == uid_str
+            # TODO: think about how to use extra_metadata for pushing when it comes to this
+            grad_outputs, = flat_extra_inputs
+
+            print("FLAT INPUTS", flat_inputs)
+            print("GRAD OUTPUTS", grad_outputs)
+            print(backward_args_structure)
+
+            grads = await run_rpc_backward(
+                flat_inputs[0],
+                grad_outputs,
+                *flat_inputs[1:],
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
+                args_structure=backward_args_structure,
+            )
+
+            # Split the serialized grad_inputs for streaming and respond
+            for tensor in self._serialize_grads(grads, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    print("SENDING GRADS:", part)
+                    yield runtime_pb2.ExpertResponse(tensors=[part])