pull/485/head
Aleksandr Borzunov 10 months ago
parent 13e61a7921
commit 7fe2635ee6

@@ -157,7 +157,7 @@ def test_input_ids_and_embeds(tokenizer, model, ref_model, max_new_tokens=4):
assert inputs.keys() == {"input_ids", "attention_mask"}
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
-ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens)
+ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):

Loading…
Cancel
Save