petals/tests/test_priority_pool.py
Alexander Borzunov 7bd5916744
Make Petals a pip-installable package (attempt 2) (#102)
1. Petals can now be installed using `pip install git+https://github.com/bigscience-workshop/petals`
    - If you have already cloned the repo, you can run `pip install .` or `pip install .[dev]`
2. Moved `src` => `src/petals`
    - Replaced `from src.smth import smth` with `from petals.smth import smth`
3. Moved `cli` => `src/petals/cli`
    - Replaced `python -m cli.run_smth` with `python -m petals.cli.run_smth` (all utilities are now available right after pip installation)
4. Moved the `requirements*.txt` contents to `setup.cfg` (modern packaging tools do not support `requirements.txt` well for packages)
5. Increased the package version from `0.2` to `1.0alpha1`
2022-11-30 10:41:13 +04:00

72 lines
2.8 KiB
Python

import multiprocessing as mp
import time
import pytest
import torch
from petals.server.runtime import Runtime
from petals.server.task_pool import PrioritizedTaskPool
@pytest.mark.forked
def test_priority_pools():
    """Check that PrioritizedTaskPools served by a shared Runtime run tasks in priority order.

    Two pools ("A" and "B", one task per batch) are registered with a single Runtime.
    A separate process submits nine tasks with different priorities.
    The pool function records the order in which tasks actually execute.
    As the expected order below shows, a smaller ``priority`` value is served first.
    Ties are broken by submission order, across both pools.
    """
    outputs_queue = mp.SimpleQueue()  # (input, output) pairs, in the order tasks were executed
    results_valid = mp.Event()  # set by the submitting process once every future returned i**2
    # Seconds. The whole workload is ~9 * 0.1s, so this only guards against a hang.
    child_timeout = 60

    def dummy_pool_func(x):
        # Slow enough that all later tasks are queued while this one is still running.
        time.sleep(0.1)
        y = x**2
        outputs_queue.put((x, y))
        return (y,)

    class DummyBackend:
        def __init__(self, pools):
            self.pools = pools

        def get_pools(self):
            return self.pools

    pools = (
        PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
        PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
    )

    def process_tasks():
        futures = []
        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
        # Presumably gives the Runtime time to pick up task 0 before the rest arrive.
        time.sleep(0.01)
        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
        for i, f in enumerate(futures):
            assert f.result()[0].item() == i**2
        results_valid.set()

    runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
    runtime.start()
    try:
        proc = mp.Process(target=process_tasks)
        proc.start()
        proc.join(timeout=child_timeout)
        if proc.is_alive():
            # Don't hang the whole test run if the runtime or the pools deadlock.
            proc.terminate()
            proc.join()
            pytest.fail(f"Task-submitting process did not finish within {child_timeout} seconds")
        assert results_valid.is_set(), f"Task-submitting process failed (exit code {proc.exitcode})"

        # Every task's output was put before its future resolved, so the queue is complete here.
        ordered_outputs = []
        while not outputs_queue.empty():
            ordered_outputs.append(outputs_queue.get()[0].item())

        assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]
        # 0 - first batch is loaded immediately, before everything else
        # 5 - priority 0, the smallest value overall, so it is served first
        # 1 - first of several tasks with the smallest remaining priority value (1)
        # 2 - second earliest task with priority 1, fetched from pool B
        # 6 - third earliest task with priority 1, fetched from pool A again
        # 8 - last priority-1 task, pool B
        # 3 - task with priority 2 from pool A
        # 4 - task with priority 10 from pool A
        # 7 - task with priority 11 from pool B
    finally:
        # Stop the runtime thread and its pools even when an assertion above fails.
        runtime.shutdown()