|
|
|
@ -4,8 +4,8 @@ import time
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind.moe.server.runtime import Runtime
|
|
|
|
|
|
|
|
|
|
from petals.server.server import RuntimeWithDeduplicatedPools
|
|
|
|
|
from petals.server.task_pool import PrioritizedTaskPool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -35,8 +35,7 @@ def test_priority_pools():
|
|
|
|
|
runtime_ready = mp.Event()
|
|
|
|
|
results_valid = mp.Event()
|
|
|
|
|
|
|
|
|
|
def dummy_pool_func(args, kwargs):
|
|
|
|
|
(x,) = args # TODO modify the PriorityPool code such that dummy_pool_func can accept x directly
|
|
|
|
|
def dummy_pool_func(x):
|
|
|
|
|
time.sleep(0.1)
|
|
|
|
|
y = x**2
|
|
|
|
|
outputs_queue.put((x, y))
|
|
|
|
@ -58,7 +57,9 @@ def test_priority_pools():
|
|
|
|
|
proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
|
|
|
|
|
proc.start()
|
|
|
|
|
|
|
|
|
|
runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
|
|
|
|
|
runtime = RuntimeWithDeduplicatedPools(
|
|
|
|
|
{str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0
|
|
|
|
|
)
|
|
|
|
|
runtime.ready = runtime_ready
|
|
|
|
|
runtime.start()
|
|
|
|
|
|
|
|
|
|