From 26ebbfe8f0c2ff9870116fe85549e272f5d78bf8 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 29 Aug 2023 07:49:27 +0400 Subject: [PATCH] Support macOS (#477) This PR makes both clients and servers work on macOS. Specifically, it: - Follows https://github.com/learning-at-home/hivemind/pull/586 to run a macOS-compatible `p2pd` binary (both x86-64 and ARM64 are supported) - Fixes forking issues and tests on macOS, Python 3.10+ - Introduces basic support for serving model blocks on Apple M1/M2 GPUs (torch.mps) - Increases max number of open files by default (it's not enough on Linux and is really small on macOS) --- .github/workflows/run-tests.yaml | 38 ++++++++++++------------- README.md | 16 ++++++++--- setup.cfg | 3 +- src/petals/__init__.py | 6 ++++ src/petals/cli/run_server.py | 20 +++++++++---- src/petals/server/reachability.py | 2 +- src/petals/server/server.py | 25 ++++++++++++++-- src/petals/server/throughput.py | 18 ++++++++---- tests/test_cache.py | 4 +-- tests/test_priority_pool.py | 47 +++++++++++++++++++------------ 10 files changed, 118 insertions(+), 61 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index f4a41f2..663b400 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -7,20 +7,21 @@ on: jobs: run-tests: - runs-on: ubuntu-latest strategy: matrix: include: - - { model: 'bigscience/bloom-560m', python-version: '3.8' } - - { model: 'bigscience/bloom-560m', python-version: '3.9' } - - { model: 'bigscience/bloom-560m', python-version: '3.10' } - - { model: 'bigscience/bloom-560m', python-version: '3.11' } - - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' } - - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' } + - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' } + - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' } + - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' } + - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' } + - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' } + - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' } fail-fast: false + runs-on: ${{ matrix.os }}-latest timeout-minutes: 15 steps: - name: Increase swap space + if: ${{ matrix.os == 'ubuntu' }} uses: pierotofy/set-swap-space@master with: swap-size-gb: 10 @@ -47,12 +48,7 @@ jobs: export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}" export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}" - # [Step 1] Watch free RAM (lack of RAM is a common issue in CI) - - bash -c 'while true; do free -h && sleep 30s; done' & - RAM_WATCH_PID=$! 
- - # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) + # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) python -m petals.cli.run_dht \ --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log & @@ -61,7 +57,7 @@ jobs: export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs - sleep 5 # wait for DHT init + until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \ --mean_balance_check_period 10 \ @@ -95,11 +91,15 @@ jobs: sleep 30 # wait for servers to eval throughput, download layers, and rebalance kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init - # [Step 3] Run PyTest + # [Step 2] Run PyTest + + # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93 + export no_proxy=* + export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES pytest tests --durations=0 --durations-min=1.0 -v - # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers) + # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers) python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ --seq_len 3 @@ -110,9 +110,7 @@ jobs: python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm - # [Step 5] Clean up - - kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests + # [Step 4] Clean up - kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID + kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID echo "Done!" diff --git a/README.md b/README.md index ef3ca3f..6987489 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ python -m petals.cli.run_server petals-team/StableBeluga2 πŸͺŸ **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki. -πŸ‹ **Any OS + Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD): +πŸ‹ **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD): ```bash sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ @@ -59,12 +59,20 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2 ``` +🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands: + +```bash +brew install python +python3 -m pip install git+https://github.com/bigscience-workshop/petals +python3 -m petals.cli.run_server petals-team/StableBeluga2 +``` +

-    📚  Learn more (using multiple GPUs, starting on boot, etc.)
-            
-    💬  Ask for help in Discord
+    📚  Learn more (how to use multiple GPUs, start the server on boot, etc.)

+πŸ’¬ **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)! + πŸ¦™ **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and πŸ€— [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an πŸ”‘ [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command. πŸ”’ **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). diff --git a/setup.cfg b/setup.cfg index 9d16b25..cf14434 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,6 +18,7 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering Topic :: Scientific/Engineering :: Mathematics Topic :: Scientific/Engineering :: Artificial Intelligence @@ -39,7 +40,7 @@ install_requires = transformers>=4.32.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py speedtest-cli==2.1.3 pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet - hivemind==1.1.9 + hivemind @ git+https://github.com/learning-at-home/hivemind tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 404a481..4e4a9d0 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,7 +1,13 @@ import os +import platform os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1") +if platform.system() == "Darwin": + # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93 + os.environ.setdefault("no_proxy", "*") + os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES") + import hivemind import transformers from packaging import version diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 720d64d..3728c16 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -1,8 +1,10 @@ import argparse +import logging import configargparse +import torch from hivemind.proto.runtime_pb2 import CompressionType -from hivemind.utils.limits import increase_file_limit +from hivemind.utils import limits from hivemind.utils.logging import get_logger from humanfriendly import parse_size @@ -127,9 +129,9 @@ def main(): group.add_argument('--new_swarm', action='store_true', help='Start a new private swarm (i.e., do not connect to any initial peers)') - parser.add_argument('--increase_file_limit', action='store_true', - help='On *nix, this will increase the max number of processes ' - 'a server can spawn before hitting "Too many open files"; Use at your own risk.') + parser.add_argument('--increase_file_limit', type=int, default=4096, + help='On *nix, increase the max number of files a server can open ' + 'before hitting "Too many open files" (set to zero to keep the system limit)') parser.add_argument('--stats_report_interval', type=int, required=False, help='Interval between two reports of batch processing performance statistics') @@ -185,8 +187,10 @@ def main(): args["startup_timeout"] = args.pop("daemon_startup_timeout") - if args.pop("increase_file_limit"): - increase_file_limit() + file_limit = args.pop("increase_file_limit") + if file_limit: + limits.logger.setLevel(logging.WARNING) + 
limits.increase_file_limit(file_limit, file_limit) compression_type = args.pop("compression").upper() compression = getattr(CompressionType, compression_type) @@ -207,6 +211,10 @@ def main(): validate_version() + if not torch.backends.openmp.is_available(): + # Necessary to prevent the server from freezing after forks + torch.set_num_threads(1) + server = Server( **args, host_maddrs=host_maddrs, diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py index 497ddeb..179b4ba 100644 --- a/src/petals/server/reachability.py +++ b/src/petals/server/reachability.py @@ -140,7 +140,7 @@ class ReachabilityProtocol(ServicerBase): protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS) ready.set_result(True) - logger.info("Reachability service started") + logger.debug("Reachability service started") async with protocol.serve(common_p2p): await protocol._stop.wait() diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 40865aa..562b56e 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -9,7 +9,9 @@ import time from typing import Dict, List, Optional, Sequence, Union import hivemind +import psutil import torch +import torch.mps from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.runtime import Runtime @@ -154,13 +156,25 @@ class Server: self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" device = torch.device(device) if device.type == "cuda" and device.index is None: device = torch.device(device.type, index=0) self.device = device torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype]) + if device.type == "cpu" and torch_dtype == torch.float16: + raise ValueError( + f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16" + ) + if device.type == "mps" and torch_dtype == torch.bfloat16: + logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead") + torch_dtype = torch.float16 self.torch_dtype = torch_dtype if tensor_parallel_devices is None: @@ -253,13 +267,14 @@ class Server: self.stop = threading.Event() def _choose_num_blocks(self) -> int: - assert self.device.type == "cuda", ( + assert self.device.type in ("cuda", "mps"), ( "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. " "CPU-only servers in the public swarm are discouraged since they are much slower" ) num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1 if num_devices > 1: + assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}" memory_per_device = tuple( torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices ) @@ -270,8 +285,10 @@ class Server: "Please launch individual servers on each GPU or set --num_blocks manually to " "override this exception." 
) - else: + elif self.device.type == "cuda": total_memory = torch.cuda.get_device_properties(self.device).total_memory + else: + total_memory = psutil.virtual_memory().total gib = 1024**3 # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models) @@ -373,6 +390,8 @@ class Server: f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, " f"{reserved_vram / gib:.1f} GiB reserved memory" ) + elif self.device.type == "mps": + torch.mps.empty_cache() def _choose_blocks(self) -> List[int]: if self.strict_block_indices is not None: diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 2806183..bf71f44 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Dict, Optional, Sequence, Union import torch +import torch.mps from hivemind.utils.logging import get_logger from transformers import PretrainedConfig @@ -207,14 +208,12 @@ def measure_compute_rps( elapsed = 0 dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype) _, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time - if device.type == "cuda": - torch.cuda.synchronize(device) + synchronize(device) start_time = time.perf_counter() - for step in range(n_steps): + for _ in range(n_steps): _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None) - if device.type == "cuda": - torch.cuda.synchronize(device) + synchronize(device) elapsed = time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed @@ -230,8 +229,15 @@ def measure_compute_rps( return device_rps +def synchronize(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize(device) + elif device.type == "mps": + torch.mps.synchronize() + + def get_device_name(device: torch.device) -> str: - return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU" + return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper() def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str: diff --git a/tests/test_cache.py b/tests/test_cache.py index 6d40db1..65cbecc 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -118,7 +118,7 @@ async def test_cache_usage(): allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f)) # klogs the cache await allocate_f_task - alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True) + alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True) alloc_process1.start() async def _allocate_bcde(): @@ -128,7 +128,7 @@ async def test_cache_usage(): allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED) - alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True) + alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True) alloc_process2.start() assert cache.current_size_bytes == 0 alloc_event.set() diff --git a/tests/test_priority_pool.py b/tests/test_priority_pool.py index 2623bb1..1a0b1da 100644 --- a/tests/test_priority_pool.py +++ b/tests/test_priority_pool.py @@ -1,4 +1,5 @@ import multiprocessing as mp +import platform import time import pytest @@ -8,9 +9,30 @@ from 
hivemind.moe.server.runtime import Runtime from petals.server.task_pool import PrioritizedTaskPool +def _submit_tasks(runtime_ready, pools, results_valid): + runtime_ready.wait() + + futures = [] + futures.append(pools[0].submit_task(torch.tensor([0]), priority=1)) + futures.append(pools[0].submit_task(torch.tensor([1]), priority=1)) + time.sleep(0.01) + futures.append(pools[1].submit_task(torch.tensor([2]), priority=1)) + futures.append(pools[0].submit_task(torch.tensor([3]), priority=2)) + futures.append(pools[0].submit_task(torch.tensor([4]), priority=10)) + futures.append(pools[0].submit_task(torch.tensor([5]), priority=0)) + futures.append(pools[0].submit_task(torch.tensor([6]), priority=1)) + futures.append(pools[1].submit_task(torch.tensor([7]), priority=11)) + futures.append(pools[1].submit_task(torch.tensor([8]), priority=1)) + for i, f in enumerate(futures): + assert f.result()[0].item() == i**2 + results_valid.set() + + +@pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks") @pytest.mark.forked def test_priority_pools(): outputs_queue = mp.SimpleQueue() + runtime_ready = mp.Event() results_valid = mp.Event() def dummy_pool_func(x): @@ -31,27 +53,14 @@ def test_priority_pools(): PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1), ) + # Simulate requests coming from ConnectionHandlers + proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid)) + proc.start() + runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0) + runtime.ready = runtime_ready runtime.start() - def process_tasks(): - futures = [] - futures.append(pools[0].submit_task(torch.tensor([0]), priority=1)) - futures.append(pools[0].submit_task(torch.tensor([1]), priority=1)) - time.sleep(0.01) - futures.append(pools[1].submit_task(torch.tensor([2]), priority=1)) - futures.append(pools[0].submit_task(torch.tensor([3]), priority=2)) - futures.append(pools[0].submit_task(torch.tensor([4]), priority=10)) - futures.append(pools[0].submit_task(torch.tensor([5]), priority=0)) - futures.append(pools[0].submit_task(torch.tensor([6]), priority=1)) - futures.append(pools[1].submit_task(torch.tensor([7]), priority=11)) - futures.append(pools[1].submit_task(torch.tensor([8]), priority=1)) - for i, f in enumerate(futures): - assert f.result()[0].item() == i**2 - results_valid.set() - - proc = mp.Process(target=process_tasks) - proc.start() proc.join() assert results_valid.is_set() @@ -69,3 +78,5 @@ def test_priority_pools(): # 3 - task with priority 2 from pool A # 4 - task with priority 10 from pool A # 7 - task with priority 11 from pool B + + runtime.shutdown()
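
Note: the device and dtype fallback that this patch adds to `Server.__init__` can be summarized by the standalone sketch below. It is only an illustration of the patch's logic under the same assumptions (CUDA first, then Apple MPS, then CPU; bfloat16 downgraded to float16 on MPS; float16 rejected on CPU); `pick_device_and_dtype` is a hypothetical helper name, not part of the Petals API.

```python
from typing import Tuple

import torch


def pick_device_and_dtype(requested_dtype: torch.dtype) -> Tuple[torch.device, torch.dtype]:
    """Mirror the fallback order used by the patched Server.__init__: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    dtype = requested_dtype
    if device.type == "cpu" and dtype == torch.float16:
        # Same check as in the patched server: float16 is rejected on CPU
        raise ValueError("float16 is not supported on CPU; use float32 or bfloat16")
    if device.type == "mps" and dtype == torch.bfloat16:
        # bfloat16 is not supported by the MPS backend, so fall back to float16
        dtype = torch.float16
    return device, dtype


if __name__ == "__main__":
    device, dtype = pick_device_and_dtype(torch.bfloat16)
    print(f"Would serve model blocks on {device} using {dtype}")
```

On an Apple M1/M2 machine without CUDA, this should report the `mps` device with `torch.float16`, matching the warning path added in the server diff above.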