Test repetition_penalty, refactor tests

pull/464/head
Aleksandr Borzunov 10 months ago
parent 30a8c22ca9
commit 99e7ecf25c

@@ -103,12 +103,11 @@ def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
         "input_ids"
     ]
 
+    options = dict(max_new_tokens=max_new_tokens, do_sample=False)
     for multiple_calls in [False, True]:
         for inputs in [inputs_single, inputs_batch]:
-            outputs = make_generate_calls(
-                model, inputs, max_new_tokens=max_new_tokens, multiple_calls=multiple_calls, do_sample=False
-            )
-            ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
+            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
+            ref_outputs = ref_model.generate(inputs, **options)
             assert torch.allclose(
                 outputs, ref_outputs
             ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"
@@ -124,32 +123,34 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=4):
         "input_ids"
     ]
 
-    for sampling_options in [
+    for options in [
         dict(do_sample=True),
         dict(do_sample=True, temperature=0.5),
         dict(do_sample=True, temperature=0.5, top_k=5),
         dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
         dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9, multiple_calls=True),
+        dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
     ]:
-        multiple_calls = sampling_options.pop("multiple_calls", False)
+        options.update(max_new_tokens=max_new_tokens)
+        multiple_calls = options.pop("multiple_calls", False)
         for inputs in [inputs_single, inputs_batch]:
             torch.manual_seed(0)
-            outputs = make_generate_calls(
-                model, inputs, max_new_tokens=max_new_tokens, multiple_calls=multiple_calls, **sampling_options
-            )
+            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
             torch.manual_seed(0)
-            ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, **sampling_options)
+            ref_outputs = ref_model.generate(inputs, **options)
             assert torch.allclose(
                 outputs, ref_outputs
-            ), f"Sampling is not identical to HF with {sampling_options=}, {inputs.shape=}"
+            ), f"Sampling is not identical to HF with {options=}, {inputs.shape=}"
 
 
 @pytest.mark.forked
-def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=6):
+def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
 
-    outputs = make_generate_calls(model, inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
-    ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
+    options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
+    outputs = make_generate_calls(model, inputs, **options)
+    ref_outputs = ref_model.generate(inputs, **options)
     assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF with {multiple_calls=}"
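
For context on the new test case, repetition_penalty=1.2 rescales the logits of tokens that already appear in the sequence, which is what the added dict(do_sample=True, temperature=0.5, repetition_penalty=1.2) entry exercises against the Hugging Face reference. Below is a minimal sketch of that transformation, assuming the rule used by transformers' RepetitionPenaltyLogitsProcessor; the helper name and example values are illustrative and not part of this commit.

import torch

def apply_repetition_penalty(logits: torch.Tensor, input_ids: torch.Tensor, penalty: float) -> torch.Tensor:
    # logits: (batch, vocab_size); input_ids: (batch, seq_len) of tokens generated so far.
    # Positive logits of already-seen tokens are divided by the penalty and negative ones
    # are multiplied by it, so repeats become less likely whenever penalty > 1.
    seen_scores = torch.gather(logits, 1, input_ids)
    seen_scores = torch.where(seen_scores < 0, seen_scores * penalty, seen_scores / penalty)
    return logits.scatter(1, input_ids, seen_scores)

logits = torch.tensor([[2.0, -1.0, 0.5]])   # toy vocabulary of 3 tokens
input_ids = torch.tensor([[0, 1]])          # tokens 0 and 1 were already generated
print(apply_repetition_penalty(logits, input_ids, penalty=1.2))
# roughly [[1.667, -1.200, 0.500]] -- only the repeated tokens are penalized

The test assumes Petals and the reference model apply this step identically, so after seeding the RNG the sampled token ids can be compared directly with torch.allclose.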
