|
|
|
@ -165,6 +165,7 @@ def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
|
|
|
|
|
[
|
|
|
|
|
model.generate(**inputs, max_new_tokens=2),
|
|
|
|
|
model.generate(None, max_new_tokens=max_new_tokens - 2),
|
|
|
|
|
]
|
|
|
|
|
],
|
|
|
|
|
dim=1,
|
|
|
|
|
)
|
|
|
|
|
assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"
|
|
|
|
|