@@ -103,12 +103,11 @@ def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
         "input_ids"
     ]

+    options = dict(max_new_tokens=max_new_tokens, do_sample=False)
     for multiple_calls in [False, True]:
         for inputs in [inputs_single, inputs_batch]:
-            outputs = make_generate_calls(
-                model, inputs, max_new_tokens=max_new_tokens, multiple_calls=multiple_calls, do_sample=False
-            )
-            ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
+            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
+            ref_outputs = ref_model.generate(inputs, **options)
             assert torch.allclose(
                 outputs, ref_outputs
             ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"
@@ -124,32 +123,34 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=4):
         "input_ids"
     ]

-    for sampling_options in [
+    for options in [
         dict(do_sample=True),
         dict(do_sample=True, temperature=0.5),
         dict(do_sample=True, temperature=0.5, top_k=5),
         dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
         dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9, multiple_calls=True),
         dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
     ]:
-        multiple_calls = sampling_options.pop("multiple_calls", False)
+        options.update(max_new_tokens=max_new_tokens)
+        multiple_calls = options.pop("multiple_calls", False)

         for inputs in [inputs_single, inputs_batch]:
             torch.manual_seed(0)
-            outputs = make_generate_calls(
-                model, inputs, max_new_tokens=max_new_tokens, multiple_calls=multiple_calls, **sampling_options
-            )
+            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)

             torch.manual_seed(0)
-            ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, **sampling_options)
+            ref_outputs = ref_model.generate(inputs, **options)

             assert torch.allclose(
                 outputs, ref_outputs
-            ), f"Sampling is not identical to HF with {sampling_options=}, {inputs.shape=}"
+            ), f"Sampling is not identical to HF with {options=}, {inputs.shape=}"


 @pytest.mark.forked
-def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=6):
+def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

-    outputs = make_generate_calls(model, inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
-    ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
+    options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
+    outputs = make_generate_calls(model, inputs, **options)
+    ref_outputs = ref_model.generate(inputs, **options)
     assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF with {multiple_calls=}"
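The tests above call a make_generate_calls helper that is not part of this hunk. The sketch below is only an illustration of what such a helper might do, assuming model.generate follows the Hugging Face interface; the repository's actual implementation may split the generation differently (for example, through a persistent inference session).

def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **options):
    # Hypothetical sketch, not the repository's actual helper.
    if not multiple_calls:
        # Single call: behaves exactly like a plain HF generate() call.
        return model.generate(inputs, max_new_tokens=max_new_tokens, **options)

    # Multiple calls: generate one token at a time, feeding each call's output
    # back in as the next prompt. For greedy decoding this should yield the same
    # tokens as a single call; for sampling it relies on the RNG being consumed
    # in the same order as in a single generate() call.
    outputs = inputs
    for _ in range(max_new_tokens):
        outputs = model.generate(outputs, max_new_tokens=1, **options)
    return outputs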
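One detail from the sampling test worth spelling out: torch.manual_seed(0) is called immediately before both the model under test and the reference model generate, so both sides draw the same random numbers and the sampled outputs can be compared token for token. A minimal, self-contained illustration of that property:

import torch

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

torch.manual_seed(0)
first = torch.multinomial(probs, num_samples=8, replacement=True)

torch.manual_seed(0)
second = torch.multinomial(probs, num_samples=8, replacement=True)

assert torch.equal(first, second)  # identical seeds give identical draws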