Fix test for multi-call generate()

pull/464/head
Aleksandr Borzunov 10 months ago
parent 56cfacdfb9
commit c066ddf06e

@@ -78,8 +78,8 @@ def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empt
 @pytest.mark.forked
-@pytest.mark.parametrize("multiple_steps", [False, True])
-def test_greedy_generation(tokenizer, model, ref_model, multiple_steps, max_new_tokens=4):
+@pytest.mark.parametrize("multiple_calls", [False, True])
+def test_greedy_generation(tokenizer, model, ref_model, multiple_calls, max_new_tokens=4):
     inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     if tokenizer.pad_token_id is None:
@@ -89,15 +89,15 @@ def test_greedy_generation(tokenizer, model, ref_model, multiple_steps, max_new_tokens=4):
     ]
     for inputs in [inputs_single, inputs_batch]:
-        if not multiple_steps:
+        if not multiple_calls:
             outputs = model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
         else:
             with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
                 outputs = [
                     # Sessions provided both explicitly and implicitly should work
                     model.generate(inputs, max_new_tokens=1, do_sample=False, session=sess),
-                    model.generate(inputs, max_new_tokens=max_new_tokens - 2, do_sample=False),
-                    model.generate(inputs, max_new_tokens=1, do_sample=False),
+                    model.generate(None, max_new_tokens=max_new_tokens - 2, do_sample=False),
+                    model.generate(None, max_new_tokens=1, do_sample=False),
                 ]
                 outputs = torch.cat(outputs, dim=1)
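
For context, the pattern this test exercises can be sketched as follows. This is a minimal illustration, not part of the commit: the model and tokenizer objects are assumed to come from the test fixtures above, the session semantics (an explicit session= argument on the first call, inputs=None on continuation calls) are taken from the diff, and the return shapes noted in the comments are assumptions about the client's behavior.

# Minimal sketch of the multi-call generate() pattern (assumptions noted above).
# Assumed behavior: the first call returns prompt + new tokens, and continuation
# calls with inputs=None return only the newly generated tokens, so concatenating
# along dim=1 reconstructs the full sequence.
import torch

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
max_new_tokens = 4

# Reference: generate all tokens in a single call.
ref_outputs = model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)

# Same generation split across three calls sharing one inference session.
with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
    chunks = [
        # Explicit session: feeds the prompt and generates the first token.
        model.generate(inputs, max_new_tokens=1, do_sample=False, session=sess),
        # Implicit session (the active context manager); inputs=None continues
        # from the session's cached state instead of re-feeding the prompt.
        model.generate(None, max_new_tokens=max_new_tokens - 2, do_sample=False),
        model.generate(None, max_new_tokens=1, do_sample=False),
    ]
    outputs = torch.cat(chunks, dim=1)

# Greedy decoding is deterministic, so both paths should yield identical tokens.
assert torch.equal(outputs, ref_outputs)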
