Fix beam search in GPU clients (#531)

Fixes #503.
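
With this change, `InferenceSession.step` accepts `hypo_ids` explicitly, validates them, and moves them to CPU together with `inputs` and `prompts` before dispatching to the servers, so beam search works when the client keeps its tensors on a GPU. For readers unfamiliar with `hypo_ids`, here is a minimal, self-contained sketch of what they represent in beam search (toy tensors only, not Petals code):

```python
import torch

# Toy illustration: hypo_ids are int64 indices saying which hypothesis (beam)
# each row of the next step's batch continues. Any per-beam state kept between
# steps (e.g. an attention cache) must be gathered with the same indices.
num_beams, prefix_len, hid_size = 3, 5, 8
cache = torch.randn(num_beams, prefix_len, hid_size)   # per-beam state from the previous step
hypo_ids = torch.tensor([1, 0, 0], dtype=torch.int64)  # beam 1 survives twice, beam 2 is dropped
cache = cache[hypo_ids]                                 # reorder state to match the new beams
assert cache.shape == (num_beams, prefix_len, hid_size)
```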
Alexander Borzunov 7 months ago committed by GitHub
parent 47d50e1e29
commit 82a97d6e9e

@@ -48,7 +48,6 @@ jobs:
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
# [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
@@ -61,27 +60,25 @@ jobs:
until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
--mean_balance_check_period 10 \
--initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
--device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
SERVER1_PID=$!
# ^-- rebalancing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
sleep 10 # wait for the 1st server to choose blocks
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
--identity_path tests/server2.id \
--initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
$RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
SERVER2_PID=$!
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
--attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
--attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
SERVER3_PID=$!
# ^-- chunking test
python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
--initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
$RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
SERVER4_PID=$!
# ^-- tensor parallelism test (not compatible with adapters yet)
@@ -121,4 +118,3 @@ jobs:
# [Step 4] Clean up
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
echo "Done!"

@@ -84,12 +84,7 @@ class _ServerInferenceSession:
break # this message means "done sending"
def step(
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
*,
step_id: str,
self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -114,21 +109,6 @@ class _ServerInferenceSession:
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
@@ -275,7 +255,9 @@ class InferenceSession:
assert not self._closed and not self._server_sessions
return self
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
def step(
self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -285,11 +267,21 @@ class InferenceSession:
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
inputs_device = inputs.device
inputs_dtype = inputs.dtype
inputs = inputs.cpu()
prompts = prompts.cpu()
hypo_ids = hypo_ids.cpu()
step_id = str(uuid.uuid4())
n_input_tokens = inputs.shape[1]
@@ -310,7 +302,7 @@ class InferenceSession:
server_session = self._server_sessions[server_idx]
inputs = server_session.step(
inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
)
server_idx += 1
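
Putting the last hunk in context, this is roughly the client-side dispatch after the change (a sketch only: `sessions` stands in for `self._server_sessions`, and the surrounding retry logic of the real method is omitted):

```python
# Illustrative sketch: each server session handles a contiguous span of blocks,
# so it receives just its slice of the deep prompts, while hypo_ids describe
# the whole batch and are forwarded unchanged to every server.
def _dispatch(inputs, prompts, hypo_ids, sessions, step_id):
    for session in sessions:
        inputs = session.step(
            inputs, prompts[session.span.start : session.span.end], hypo_ids, step_id=step_id
        )
    return inputs
```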
