reduce diff

pull/467/head
Your Name 8 months ago
parent 3f06b53b1d
commit c665c42cf2

@ -34,7 +34,7 @@ class _ServerInferenceSession:
span: RemoteSpanInfo,
span_uids: Sequence[ModuleUID],
inputs_queue: asyncio.Queue,
outputs_stream: AsyncIterator,
outputs_aiter: AsyncIterator,
*block_kwargs,
max_length: int,
):
@ -42,7 +42,7 @@ class _ServerInferenceSession:
self.span, self.span_uids = span, span_uids
self.num_blocks = len(span_uids)
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.session_id = str(uuid.uuid4())
self.max_length = max_length
self.stepped = False

@ -233,7 +233,6 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
prompt_batches = tuple(batch.requires_grad_(prompts.requires_grad) for batch in prompt_batches)
sequence_manager.rpc_info # lazy init #TODO no longer needed
outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
assert len(outputs) == len(input_batches)

@ -4,8 +4,8 @@ import time
import pytest
import torch
from hivemind.moe.server.runtime import Runtime
from petals.server.server import RuntimeWithDeduplicatedPools
from petals.server.task_pool import PrioritizedTaskPool
@ -35,8 +35,7 @@ def test_priority_pools():
runtime_ready = mp.Event()
results_valid = mp.Event()
def dummy_pool_func(args, kwargs):
(x,) = args # TODO modify the PriorityPool code such that dummy_pool_func can accept x directly
def dummy_pool_func(x):
time.sleep(0.1)
y = x**2
outputs_queue.put((x, y))
@ -58,7 +57,9 @@ def test_priority_pools():
proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
proc.start()
runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
runtime = RuntimeWithDeduplicatedPools(
{str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0
)
runtime.ready = runtime_ready
runtime.start()

Loading…
Cancel
Save