Warn that multi-call beam search doesn't work properly

pull/464/head
Aleksandr Borzunov 10 months ago
parent 4d1c228dd6
commit 30a8c22ca9

@ -78,7 +78,14 @@ class RemoteGenerationMixin:
with context_manager as session:
# Prepend the last tokens from the previous .generate() call
if session.last_token_id is not None:
resuming_session = session.last_token_id is not None
if resuming_session:
if kwargs.get("num_beams", 1) > 1:
logger.warning(
"Beam search will not work properly in the resumed petals.InferenceSession "
"since intermediate beam entries are lost"
)
assert session.last_token_id.shape[1] == 1, f"{session.last_token_id.shape=} is invalid"
if inputs is not None:
inputs = torch.cat([session.last_token_id, inputs], dim=1)
@ -89,7 +96,7 @@ class RemoteGenerationMixin:
sequences = result.sequences if isinstance(result, ModelOutput) else result
# Crop the last tokens from the previous call
if session.last_token_id is not None:
if resuming_session:
sequences = sequences[:, 1:]
if isinstance(result, ModelOutput):
result.sequences = sequences

@ -150,16 +150,6 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=4):
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=6):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
for multiple_calls in [False, True]:
outputs = make_generate_calls(
model,
inputs,
max_new_tokens=max_new_tokens,
multiple_calls=multiple_calls,
num_beams=num_beams,
do_sample=False,
)
ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
assert torch.allclose(
outputs, ref_outputs
), f"Beam search results are not identical to HF with {multiple_calls=}"
outputs = make_generate_calls(model, inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
assert torch.allclose(outputs, ref_outputs), "Beam search results are not identical to HF"

Loading…
Cancel
Save