Merge 26d4cd855d
into 30f522d1a0
commit
088ba3f74b
@ -0,0 +1,254 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "21e78d30",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"from typing import Sequence, Tuple, Iterable, List\n",
|
||||
"from tqdm.auto import trange\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import hivemind\n",
|
||||
"import petals\n",
|
||||
"\n",
|
||||
"from petals.server.handler import TransformerConnectionHandler, split_for_streaming\n",
|
||||
"from petals.client import RemoteSequenceManager, ClientConfig\n",
|
||||
"from petals.client.remote_forward_backward import DEFAULT_MAX_MSG_SIZE, iter_as_aiter, aiter_with_timeout, deserialize_tensor_stream\n",
|
||||
"from petals.data_structures import ModuleUID, PeerID, CHAIN_DELIMITER, UID_DELIMITER\n",
|
||||
"from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs\n",
|
||||
"\n",
|
||||
"from hivemind.compression import serialize_torch_tensor\n",
|
||||
"from hivemind.utils import MSGPackSerializer, nested_flatten\n",
|
||||
"from hivemind.proto import runtime_pb2\n",
|
||||
"\n",
|
||||
"_END_OF_STREAM_KEY = \"_EOS\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def pack_as_expert_requests(uid, flat_tensors, codecs, metadata):\n",
|
||||
" # Asynchronous serialization\n",
|
||||
" loop = asyncio.get_running_loop()\n",
|
||||
" serialized_tensors = await asyncio.gather(\n",
|
||||
" *(\n",
|
||||
" loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)\n",
|
||||
" for tensor, compression in zip(flat_tensors, codecs)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" parts = [\n",
|
||||
" tensor_part for tensor in serialized_tensors\n",
|
||||
" for tensor_part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)\n",
|
||||
" ]\n",
|
||||
" if len(parts) > 1:\n",
|
||||
" serialized_metadata = MSGPackSerializer.dumps(metadata)\n",
|
||||
" serialized_metadata_last_piece = MSGPackSerializer.dumps(dict(metadata, **{_END_OF_STREAM_KEY: True}))\n",
|
||||
" \n",
|
||||
" return [\n",
|
||||
" runtime_pb2.ExpertRequest(\n",
|
||||
" uid=uid, tensors=[tensor_part], \n",
|
||||
" metadata=serialized_metadata if i != len(parts) - 1 else serialized_metadata_last_piece)\n",
|
||||
" for i, tensor_part in enumerate(parts)\n",
|
||||
" ]\n",
|
||||
" \n",
|
||||
"async def run_remote_forward_backward(\n",
|
||||
" sequence_manager: RemoteSequenceManager,\n",
|
||||
" peer_id: PeerID,\n",
|
||||
" span_uids: Sequence[ModuleUID],\n",
|
||||
" *args: torch.Tensor,\n",
|
||||
" **kwargs: torch.Tensor,\n",
|
||||
") -> Tuple[torch.Tensor, ...]:\n",
|
||||
" \"\"\"\n",
|
||||
" Serializes input tensors and calls \"rpc_forward_backward\" on a remote server.\n",
|
||||
" Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198\n",
|
||||
" but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.\n",
|
||||
" \"\"\"\n",
|
||||
" merged_uid = CHAIN_DELIMITER.join(span_uids)\n",
|
||||
" stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)\n",
|
||||
" flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)\n",
|
||||
" metadata = sequence_manager.get_request_metadata(\"rpc_forward\", args_structure, uids=span_uids, *args, peer_id=peer_id, **kwargs) #TODO fix metadata api\n",
|
||||
" #codecs = sequence_manager.get_compression_codecs(peer_id, \"rpc_forward\", span_uids, *args, **kwargs)\n",
|
||||
" codecs = [runtime_pb2.CompressionType.NONE for _ in args] #TODO replace with proper compression\n",
|
||||
" flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)\n",
|
||||
" args_structure = metadata.setdefault(\"args_structure\", args_structure)\n",
|
||||
" if codecs is None:\n",
|
||||
" codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)\n",
|
||||
" else:\n",
|
||||
" codecs = list(nested_flatten(codecs))\n",
|
||||
" assert len(codecs) == len(flat_tensors), f\"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # call RPC on remote server\n",
|
||||
" size = sum(t.element_size() * t.nelement() for t in flat_tensors)\n",
|
||||
" # Hotfix: we use \"// 2\" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR\n",
|
||||
" \n",
|
||||
" ### HERE BEGINS INLINED REQUEST SENDER \n",
|
||||
" # used to look like this:\n",
|
||||
" # output_tensors = await _run_forward_part(\n",
|
||||
" # merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata\n",
|
||||
" # )\n",
|
||||
" config = sequence_manager.config\n",
|
||||
" assert _END_OF_STREAM_KEY not in metadata\n",
|
||||
" forward_requests = await pack_as_expert_requests(merged_uid, flat_tensors, codecs, metadata)\n",
|
||||
" backward_codecs = [runtime_pb2.CompressionType.NONE] #TODO replace with proper compression\n",
|
||||
" fake_grad_outputs = torch.randn_like(flat_tensors[0])\n",
|
||||
" _, backward_args_structure = pack_args_kwargs(args[0], fake_grad_outputs, *args[1:], **kwargs)\n",
|
||||
" backward_metadata = dict(metadata, args_structure=backward_args_structure)\n",
|
||||
" \n",
|
||||
" grad_requests = await pack_as_expert_requests(merged_uid, (fake_grad_outputs,), backward_codecs, backward_metadata)\n",
|
||||
" \n",
|
||||
" received_outputs = asyncio.Event()\n",
|
||||
"\n",
|
||||
" async def iterate_inputs():\n",
|
||||
" for request in forward_requests:\n",
|
||||
" yield request\n",
|
||||
" print(\"WAITING FOR OUTPUTS\")\n",
|
||||
" await received_outputs.wait()\n",
|
||||
" print(\"RECEIVED OUTPUTS - SENDING GRADS\")\n",
|
||||
" for request in grad_requests:\n",
|
||||
" yield request\n",
|
||||
" print(\"SENT GRADS\")\n",
|
||||
"\n",
|
||||
" async def _wrap_input_stream(stream):\n",
|
||||
" async for expert_request in stream:\n",
|
||||
" yield expert_request\n",
|
||||
" if not expert_request.metadata:\n",
|
||||
" continue #TODO write more generally\n",
|
||||
" metadata = MSGPackSerializer.loads(expert_request.metadata)\n",
|
||||
" print(metadata)\n",
|
||||
" if metadata.get(_END_OF_STREAM_KEY):\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" print(\"CALLING stub.rpc_forward_stream on serialized inputs\", iterate_inputs())\n",
|
||||
" outputs_stream = await asyncio.wait_for(stub.rpc_forward_backward_stream(iterate_inputs()), config.connect_timeout)\n",
|
||||
" outputs_stream = aiter_with_timeout(outputs_stream, config.request_timeout)\n",
|
||||
" \n",
|
||||
" output_hidden_states = await deserialize_tensor_stream(msg.tensors async for msg in _wrap_input_stream(outputs_stream))\n",
|
||||
" received_outputs.set()\n",
|
||||
"\n",
|
||||
" grad_inputs = await deserialize_tensor_stream(msg.tensors async for msg in _wrap_input_stream(outputs_stream))\n",
|
||||
" print(\"RECEIVED GRAD INPUTS\")\n",
|
||||
" #TODOreturn output_hidden_states, grads\n",
|
||||
"\n",
|
||||
" ####\n",
|
||||
" \n",
|
||||
" # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591\n",
|
||||
" requires_grad = any(tensor.requires_grad for tensor in flat_tensors)\n",
|
||||
" output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_hidden_states]\n",
|
||||
" return output_tensors, grad_inputs\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "1c47c89a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Mar 17 18:37:25.661 [\u001b[1m\u001b[34mINFO\u001b[0m] Make sure you follow the LLaMA's terms of use: https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1\n",
|
||||
"Mar 17 18:37:25.661 [\u001b[1m\u001b[34mINFO\u001b[0m] Using DHT prefix: TinyLLama-v0-hf\n",
|
||||
"100%|██████████| 1/1 [00:00<00:00, 26.19it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CALLING stub.rpc_forward_stream on serialized inputs <async_generator object run_remote_forward_backward.<locals>.iterate_inputs at 0x75eb8d134d60>\n",
|
||||
"WAITING FOR OUTPUTS\n",
|
||||
"{'_EOS': True}\n",
|
||||
"RECEIVED OUTPUTS - SENDING GRADS\n",
|
||||
"SENT GRADS\n",
|
||||
"RECEIVED GRAD INPUTS\n",
|
||||
"outputs: tensor([[[-0.0835, 0.3027, 0.2217, ..., 1.1719 ...\n",
|
||||
"It works!\n",
|
||||
"shutting down\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"INITIAL_PEERS = ['/ip4/127.0.0.1/tcp/1337/p2p/QmRTdR9XmTHNXKiwtqRJ4i7tNofnmFrxkufBefguZUyXej']\n",
|
||||
"peer_id_string = INITIAL_PEERS[0].split(\"/\")[-1]\n",
|
||||
"model_name = \"Maykeye/TinyLLama-v0\"\n",
|
||||
"\n",
|
||||
"model_config = petals.DistributedLlamaConfig.from_pretrained(model_name)\n",
|
||||
"block_uids = [\n",
|
||||
" f\"{model_config.dht_prefix}{UID_DELIMITER}{i}\"\n",
|
||||
" for i in range(model_config.num_hidden_layers)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"block_in_use = block_uids[0:2]\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" dht = hivemind.DHT(start=True, client_mode=True, initial_peers=INITIAL_PEERS)\n",
|
||||
" sequence_manager = petals.RemoteSequenceManager(model_config, block_uids, dht=dht)\n",
|
||||
" sequence_manager.rpc_info\n",
|
||||
" p2p = await dht.replicate_p2p()\n",
|
||||
" \n",
|
||||
" dummy_inputs = [\n",
|
||||
" torch.rand(1, 128, model_config.hidden_size, dtype=model_config.torch_dtype),\n",
|
||||
" torch.empty(0, dtype=model_config.torch_dtype),\n",
|
||||
" ]\n",
|
||||
" peer_id = hivemind.PeerID.from_base58(peer_id_string)\n",
|
||||
" for i in trange(1):\n",
|
||||
" (outputs,), grads = await run_remote_forward_backward(sequence_manager, peer_id, block_in_use, *dummy_inputs)\n",
|
||||
" print('outputs:', repr(outputs)[:50], '...')\n",
|
||||
" print(\"It works!\")\n",
|
||||
"\n",
|
||||
"finally:\n",
|
||||
" print(\"shutting down\")\n",
|
||||
" await p2p.shutdown()\n",
|
||||
" dht.shutdown() # it is okay to remove this clause, but you will be summoning a horde of daemons as you debug"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f72fac2c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5392ba6a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Binary file not shown.
Loading…
Reference in New Issue