pull/485/head
Aleksandr Borzunov 10 months ago
parent 06accbcf40
commit 4ed2d0bffb

@ -165,6 +165,7 @@ def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
[
model.generate(**inputs, max_new_tokens=2),
model.generate(None, max_new_tokens=max_new_tokens - 2),
]
],
dim=1,
)
assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"

Loading…
Cancel
Save