|
|
|
@@ -150,16 +150,6 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=4):
|
|
|
|
|
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=6):
    """Check that beam-search generation matches the reference HF model.

    Runs beam search through ``make_generate_calls`` both as a single
    ``generate`` call (``multiple_calls=False``) and as several incremental
    calls (``multiple_calls=True``), and asserts that the produced token ids
    are identical to ``ref_model.generate`` with the same decoding settings.

    Args:
        tokenizer: tokenizer used to encode the fixed prompt.
        model: model under test, driven via ``make_generate_calls``.
        ref_model: reference HuggingFace model used as ground truth.
        max_new_tokens: number of tokens to generate (small to keep the test fast).
        num_beams: beam width for beam search.
    """
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    # Deterministic decoding (do_sample=False), so outputs must match exactly
    # regardless of whether generation happens in one call or several.
    for multiple_calls in [False, True]:
        outputs = make_generate_calls(
            model,
            inputs,
            max_new_tokens=max_new_tokens,
            multiple_calls=multiple_calls,
            num_beams=num_beams,
            do_sample=False,
        )
        ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
        assert torch.allclose(
            outputs, ref_outputs
        ), f"Beam search results are not identical to HF with {multiple_calls=}"