|
|
|
@ -161,10 +161,12 @@ def test_input_ids_and_embeds(tokenizer, model, ref_model, max_new_tokens=4):
|
|
|
|
|
assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
|
|
|
|
|
|
|
|
|
|
with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
|
|
|
|
|
outputs = torch.cat([
|
|
|
|
|
model.generate(**inputs, max_new_tokens=2),
|
|
|
|
|
model.generate(None, max_new_tokens=max_new_tokens - 2),
|
|
|
|
|
])
|
|
|
|
|
outputs = torch.cat(
|
|
|
|
|
[
|
|
|
|
|
model.generate(**inputs, max_new_tokens=2),
|
|
|
|
|
model.generate(None, max_new_tokens=max_new_tokens - 2),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"
|
|
|
|
|
|
|
|
|
|
inputs_embeds = model.transformer.word_embeddings(inputs["input_ids"])
|
|
|
|
@ -173,8 +175,10 @@ def test_input_ids_and_embeds(tokenizer, model, ref_model, max_new_tokens=4):
|
|
|
|
|
assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
|
|
|
|
|
|
|
|
|
|
with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
|
|
|
|
|
outputs = torch.cat([
|
|
|
|
|
model.generate(inputs_embeds=inputs_embeds, max_new_tokens=2),
|
|
|
|
|
model.generate(None, max_new_tokens=max_new_tokens - 2),
|
|
|
|
|
])
|
|
|
|
|
outputs = torch.cat(
|
|
|
|
|
[
|
|
|
|
|
model.generate(inputs_embeds=inputs_embeds, max_new_tokens=2),
|
|
|
|
|
model.generate(None, max_new_tokens=max_new_tokens - 2),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
|
|
|
|
|