|
|
|
@ -8,7 +8,7 @@ from hivemind.moe.server.runtime import Runtime
|
|
|
|
|
from petals.server.task_pool import PrioritizedTaskPool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _process_tasks(pools, results_valid):
|
|
|
|
|
def _submit_tasks(pools, results_valid):
|
|
|
|
|
torch.set_num_threads(1)
|
|
|
|
|
|
|
|
|
|
futures = []
|
|
|
|
@ -50,11 +50,13 @@ def test_priority_pools():
|
|
|
|
|
PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Simulate requests coming from ConnectionHandlers
|
|
|
|
|
proc = mp.context.ForkProcess(target=_submit_tasks, args=(pools, results_valid))
|
|
|
|
|
proc.start()
|
|
|
|
|
|
|
|
|
|
runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
|
|
|
|
|
runtime.start()
|
|
|
|
|
|
|
|
|
|
proc = mp.context.ForkProcess(target=_process_tasks, args=(pools, results_valid))
|
|
|
|
|
proc.start()
|
|
|
|
|
proc.join()
|
|
|
|
|
assert results_valid.is_set()
|
|
|
|
|
|
|
|
|
|