diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index b9dcc01..74b731d 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -48,7 +48,6 @@ jobs:
           export MODEL_NAME="${{ matrix.model }}"
           export REF_NAME="${{ matrix.model }}"
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
-          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
 
           # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
 
@@ -61,27 +60,25 @@ jobs:
 
           until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
-            --mean_balance_check_period 10 \
-            --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+          export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
+            --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
+          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
+
+          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
           SERVER1_PID=$!
           # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
 
           sleep 10  # wait for the 1st server to choose blocks
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
-            --identity_path tests/server2.id \
-            --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
           SERVER2_PID=$!
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
-            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
-            --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
+          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
+            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
           SERVER3_PID=$!
           # ^-- chunking test
 
-          python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
-            --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
+          $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
           SERVER4_PID=$!
           # ^-- tensor parallelism test (not compatible with adapters yet)
 
@@ -121,4 +118,3 @@ jobs:
 
           # [Step 4] Clean up
           kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
-          echo "Done!"
diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py
index 28d3632..34d24c7 100644
--- a/src/petals/client/inference_session.py
+++ b/src/petals/client/inference_session.py
@@ -84,12 +84,7 @@ class _ServerInferenceSession:
                 break  # this message means "done sending"
 
     def step(
-        self,
-        inputs: torch.Tensor,
-        prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None,
-        *,
-        step_id: str,
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -114,21 +109,6 @@ class _ServerInferenceSession:
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
-        if prompts is None or is_dummy(prompts):
-            prompts = DUMMY
-        else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
-            assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (inputs.shape[0], 1)
-            assert prompts.shape[2] <= inputs.shape[1]
-            assert prompts.shape[3] == inputs.shape[2]
-
-        if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = DUMMY_INT64
-        else:
-            assert len(hypo_ids) == len(inputs)
-            assert hypo_ids.dtype == torch.int64
-
         # serialize inputs and put them into the queue
         input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
 
@@ -275,7 +255,9 @@ class InferenceSession:
         assert not self._closed and not self._server_sessions
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -285,11 +267,21 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
         step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]
@@ -310,7 +302,7 @@ class InferenceSession:
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
                     )
 
                     server_idx += 1
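For reference, the shape/dtype contract that `InferenceSession.step` now checks once up front (instead of re-validating inside every `_ServerInferenceSession.step`) can be exercised standalone. Below is a minimal sketch that only mirrors the asserts moved in this diff; all dimension values are made-up illustrations, not petals defaults.

```python
import torch

# Illustrative sizes only (assumptions for this example, not library defaults)
num_blocks, batch_size, prefix_len, seq_len, hid_size = 4, 2, 3, 5, 16

inputs = torch.randn(batch_size, seq_len, hid_size)                   # hidden states for this step
prompts = torch.zeros(num_blocks, batch_size, prefix_len, hid_size)   # deep prompts, one slice per block
hypo_ids = torch.arange(batch_size, dtype=torch.int64)                # hypothesis reordering indices

# The same checks InferenceSession.step() performs once, before slicing per-server prompt spans:
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)   # per-sequence prompts or a single prompt broadcast over the batch
assert prompts.shape[2] <= inputs.shape[1]        # prompt prefix no longer than the input chunk
assert prompts.shape[3] == inputs.shape[2]        # hidden size must match
assert len(hypo_ids) == len(inputs) and hypo_ids.dtype == torch.int64
```

After these checks (or substitution of the `DUMMY` / `DUMMY_INT64` placeholders), each per-server session receives its slice of `prompts` plus the already-validated `hypo_ids`, so the server-side `step` no longer needs optional arguments or its own validation.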