Support macOS (#477)

This PR makes both clients and servers work on macOS. Specifically, it:

- Follows https://github.com/learning-at-home/hivemind/pull/586 to run a macOS-compatible `p2pd` binary (both x86-64 and ARM64 are supported)
- Fixes forking issues and tests on macOS and Python 3.10+
- Introduces basic support for serving model blocks on Apple M1/M2 GPUs (`torch.mps`)
- Increases the maximum number of open files by default (the default limit is often too low on Linux and is tiny on macOS); see the sketch below for how these server-side pieces fit together
Alexander Borzunov · commit 26ebbfe8f0 (parent 75e516a8c1)
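
Before the diff, here is a minimal, self-contained sketch of the three server-side ideas from the list above: setting the macOS fork-safety environment variables, preferring Apple's MPS backend when CUDA is absent, and raising the open-file limit. The helper names and the `resource`-based limit bump are illustrative assumptions, not the exact Petals/hivemind API; the real changes are in the diff that follows.

```python
import os
import platform
import resource  # POSIX-only; assumption: we raise RLIMIT_NOFILE directly in this sketch

import torch


def apply_macos_fork_workarounds() -> None:
    # Objective-C frameworks may abort forked children on macOS; these env vars
    # disable that check (see the pytest-parallel issue linked in the diff below).
    if platform.system() == "Darwin":
        os.environ.setdefault("no_proxy", "*")
        os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")


def pick_device() -> torch.device:
    # Prefer CUDA, then Apple M1/M2 GPUs via torch.mps, then fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda", 0)
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def raise_open_file_limit(target: int = 4096) -> None:
    # Lift the soft limit on open file descriptors up to the hard limit;
    # the macOS default (256) is far too low for a DHT node with many peers.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    new_soft = min(max(soft, target), hard if hard != resource.RLIM_INFINITY else target)
    resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))


if __name__ == "__main__":
    apply_macos_fork_workarounds()
    raise_open_file_limit()
    print(f"Serving on {pick_device()}")
```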

@@ -7,20 +7,21 @@ on:
 jobs:
   run-tests:
-    runs-on: ubuntu-latest
     strategy:
       matrix:
         include:
-          - { model: 'bigscience/bloom-560m', python-version: '3.8' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.9' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.10' }
-          - { model: 'bigscience/bloom-560m', python-version: '3.11' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
-          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
+          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
+    runs-on: ${{ matrix.os }}-latest
     timeout-minutes: 15
     steps:
       - name: Increase swap space
+        if: ${{ matrix.os == 'ubuntu' }}
         uses: pierotofy/set-swap-space@master
         with:
           swap-size-gb: 10
@@ -47,12 +48,7 @@ jobs:
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
           export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

-          # [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
-          bash -c 'while true; do free -h && sleep 30s; done' &
-          RAM_WATCH_PID=$!
-
-          # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+          # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
           python -m petals.cli.run_dht \
             --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
@@ -61,7 +57,7 @@ jobs:
           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
           # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs

-          sleep 5 # wait for DHT init
+          until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init

           python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
             --mean_balance_check_period 10 \
@@ -95,11 +91,15 @@ jobs:
           sleep 30 # wait for servers to eval throughput, download layers, and rebalance
           kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init

-          # [Step 3] Run PyTest
+          # [Step 2] Run PyTest
+
+          # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+          export no_proxy=*
+          export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

           pytest tests --durations=0 --durations-min=1.0 -v

-          # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
+          # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
           python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
             --seq_len 3
@@ -110,9 +110,7 @@ jobs:
           python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
             --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm

-          # [Step 5] Clean up
-          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests
-          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
+          # [Step 4] Clean up
+          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
           echo "Done!"

@@ -51,7 +51,7 @@ python -m petals.cli.run_server petals-team/StableBeluga2
 🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.

-🐋 **Any OS + Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
+🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):

 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
@@ -59,12 +59,20 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach
     python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
 ```

+🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
+
+```bash
+brew install python
+python3 -m pip install git+https://github.com/bigscience-workshop/petals
+python3 -m petals.cli.run_server petals-team/StableBeluga2
+```
+
 <p align="center">
-    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (using multiple GPUs, starting on boot, etc.)
+    📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
+    &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
+    💬 &nbsp;<b><a href="https://discord.gg/X7DgtxgMhc">Ask for help in Discord</a></b>
 </p>

-💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
-
 🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.

 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).

@@ -18,6 +18,7 @@ classifiers =
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
+    Programming Language :: Python :: 3.11
     Topic :: Scientific/Engineering
     Topic :: Scientific/Engineering :: Mathematics
     Topic :: Scientific/Engineering :: Artificial Intelligence
@@ -39,7 +40,7 @@ install_requires =
     transformers>=4.32.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
-    hivemind==1.1.9
+    hivemind @ git+https://github.com/learning-at-home/hivemind
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2

@@ -1,7 +1,13 @@
 import os
+import platform

 os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")

+if platform.system() == "Darwin":
+    # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
+    os.environ.setdefault("no_proxy", "*")
+    os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
+
 import hivemind
 import transformers
 from packaging import version

@@ -1,8 +1,10 @@
 import argparse
+import logging

 import configargparse
+import torch
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.limits import increase_file_limit
+from hivemind.utils import limits
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
@@ -127,9 +129,9 @@ def main():
     group.add_argument('--new_swarm', action='store_true',
                        help='Start a new private swarm (i.e., do not connect to any initial peers)')

-    parser.add_argument('--increase_file_limit', action='store_true',
-                        help='On *nix, this will increase the max number of processes '
-                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--increase_file_limit', type=int, default=4096,
+                        help='On *nix, increase the max number of files a server can open '
+                             'before hitting "Too many open files" (set to zero to keep the system limit)')
     parser.add_argument('--stats_report_interval', type=int, required=False,
                         help='Interval between two reports of batch processing performance statistics')
@@ -185,8 +187,10 @@ def main():
     args["startup_timeout"] = args.pop("daemon_startup_timeout")

-    if args.pop("increase_file_limit"):
-        increase_file_limit()
+    file_limit = args.pop("increase_file_limit")
+    if file_limit:
+        limits.logger.setLevel(logging.WARNING)
+        limits.increase_file_limit(file_limit, file_limit)

     compression_type = args.pop("compression").upper()
     compression = getattr(CompressionType, compression_type)
@@ -207,6 +211,10 @@ def main():
     validate_version()

+    if not torch.backends.openmp.is_available():
+        # Necessary to prevent the server from freezing after forks
+        torch.set_num_threads(1)
+
     server = Server(
         **args,
         host_maddrs=host_maddrs,

@@ -140,7 +140,7 @@ class ReachabilityProtocol(ServicerBase):
                 protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)

                 ready.set_result(True)
-                logger.info("Reachability service started")
+                logger.debug("Reachability service started")

                 async with protocol.serve(common_p2p):
                     await protocol._stop.wait()

@@ -9,7 +9,9 @@ import time
 from typing import Dict, List, Optional, Sequence, Union

 import hivemind
+import psutil
 import torch
+import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -154,13 +156,25 @@ class Server:
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS

         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         device = torch.device(device)
         if device.type == "cuda" and device.index is None:
             device = torch.device(device.type, index=0)
         self.device = device

         torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        if device.type == "cpu" and torch_dtype == torch.float16:
+            raise ValueError(
+                f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+            )
+        if device.type == "mps" and torch_dtype == torch.bfloat16:
+            logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
+            torch_dtype = torch.float16
         self.torch_dtype = torch_dtype

         if tensor_parallel_devices is None:
@@ -253,13 +267,14 @@ class Server:
         self.stop = threading.Event()

     def _choose_num_blocks(self) -> int:
-        assert self.device.type == "cuda", (
+        assert self.device.type in ("cuda", "mps"), (
             "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
             "CPU-only servers in the public swarm are discouraged since they are much slower"
         )
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1

         if num_devices > 1:
+            assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
             memory_per_device = tuple(
                 torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
             )
@@ -270,8 +285,10 @@ class Server:
                 "Please launch individual servers on each GPU or set --num_blocks manually to "
                 "override this exception."
             )
-        else:
+        elif self.device.type == "cuda":
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        else:
+            total_memory = psutil.virtual_memory().total

         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
@@ -373,6 +390,8 @@ class Server:
                 f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                 f"{reserved_vram / gib:.1f} GiB reserved memory"
             )
+        elif self.device.type == "mps":
+            torch.mps.empty_cache()

     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:

@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import Dict, Optional, Sequence, Union

 import torch
+import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
@@ -207,14 +208,12 @@ def measure_compute_rps(
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
         _, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time
-        if device.type == "cuda":
-            torch.cuda.synchronize(device)
+        synchronize(device)

         start_time = time.perf_counter()
-        for step in range(n_steps):
+        for _ in range(n_steps):
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-        if device.type == "cuda":
-            torch.cuda.synchronize(device)
+        synchronize(device)
         elapsed = time.perf_counter() - start_time

         device_rps = n_steps * n_tokens / elapsed
@@ -230,8 +229,15 @@ def measure_compute_rps(
     return device_rps


+def synchronize(device: torch.device):
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+    elif device.type == "mps":
+        torch.mps.synchronize()
+
+
 def get_device_name(device: torch.device) -> str:
-    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
+    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()


 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:

@@ -118,7 +118,7 @@ async def test_cache_usage():
         allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f)) # klogs the cache
         await allocate_f_task

-    alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
     alloc_process1.start()

    async def _allocate_bcde():
@@ -128,7 +128,7 @@ async def test_cache_usage():
         allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit
         await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)

-    alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
     alloc_process2.start()
     assert cache.current_size_bytes == 0
     alloc_event.set()

@@ -1,4 +1,5 @@
 import multiprocessing as mp
+import platform
 import time

 import pytest
@@ -8,9 +9,30 @@ from hivemind.moe.server.runtime import Runtime
 from petals.server.task_pool import PrioritizedTaskPool


+def _submit_tasks(runtime_ready, pools, results_valid):
+    runtime_ready.wait()
+
+    futures = []
+    futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
+    futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
+    time.sleep(0.01)
+    futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
+    futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
+    futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
+    futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
+    futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
+    futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
+    futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
+    for i, f in enumerate(futures):
+        assert f.result()[0].item() == i**2
+    results_valid.set()
+
+
+@pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks")
 @pytest.mark.forked
 def test_priority_pools():
     outputs_queue = mp.SimpleQueue()
+    runtime_ready = mp.Event()
     results_valid = mp.Event()

     def dummy_pool_func(x):
@@ -31,27 +53,14 @@ def test_priority_pools():
         PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
     )

+    # Simulate requests coming from ConnectionHandlers
+    proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
+    proc.start()
+
     runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
+    runtime.ready = runtime_ready
     runtime.start()

-    def process_tasks():
-        futures = []
-        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
-        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
-        time.sleep(0.01)
-        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
-        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
-        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
-        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
-        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
-        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
-        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
-        for i, f in enumerate(futures):
-            assert f.result()[0].item() == i**2
-        results_valid.set()
-
-    proc = mp.Process(target=process_tasks)
-    proc.start()
     proc.join()

     assert results_valid.is_set()
@@ -69,3 +78,5 @@ def test_priority_pools():
     # 3 - task with priority 2 from pool A
     # 4 - task with priority 10 from pool A
     # 7 - task with priority 11 from pool B
+
+    runtime.shutdown()
