diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 6d17388..f6195d8 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -34,7 +34,7 @@ class _ServerInferenceSession: span: RemoteSpanInfo, span_uids: Sequence[ModuleUID], inputs_queue: asyncio.Queue, - outputs_stream: AsyncIterator, + outputs_aiter: AsyncIterator, *block_kwargs, max_length: int, ): @@ -42,7 +42,7 @@ class _ServerInferenceSession: self.span, self.span_uids = span, span_uids self.num_blocks = len(span_uids) self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue - self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream + self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter self.session_id = str(uuid.uuid4()) self.max_length = max_length self.stepped = False diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 1748490..6d450e3 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -233,7 +233,6 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function): prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1) prompt_batches = tuple(batch.requires_grad_(prompts.requires_grad) for batch in prompt_batches) - sequence_manager.rpc_info # lazy init #TODO no longer needed outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager)) assert len(outputs) == len(input_batches) diff --git a/tests/test_priority_pool.py b/tests/test_priority_pool.py index 8bf673c..15c6de5 100644 --- a/tests/test_priority_pool.py +++ b/tests/test_priority_pool.py @@ -4,8 +4,8 @@ import time import pytest import torch -from hivemind.moe.server.runtime import Runtime +from petals.server.server import RuntimeWithDeduplicatedPools from petals.server.task_pool import PrioritizedTaskPool @@ -35,8 +35,7 @@ def test_priority_pools(): runtime_ready = mp.Event() results_valid = mp.Event() - def dummy_pool_func(args, kwargs): - (x,) = args # TODO modify the PriorityPool code such that dummy_pool_func can accept x directly + def dummy_pool_func(x): time.sleep(0.1) y = x**2 outputs_queue.put((x, y)) @@ -58,7 +57,9 @@ def test_priority_pools(): proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid)) proc.start() - runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0) + runtime = RuntimeWithDeduplicatedPools( + {str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0 + ) runtime.ready = runtime_ready runtime.start()