@@ -142,8 +142,8 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):


 @pytest.mark.skipif(
-    MODEL_NAME.lower().find("mixtral") > -1,
-    reason="Mixtral uses DynamicCache, which can change based on beam search choices",
+    MODEL_NAME.lower().find("bloom") == -1,
+    reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
 )
 @pytest.mark.forked
 def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
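For context on the new guard: `str.find` returns `-1` when the substring is absent, so `find("bloom") == -1` skips the test for every model whose name does not contain "bloom", whereas the old condition only skipped Mixtral. A minimal sketch of how the condition evaluates, using hypothetical model names:

```python
# Sketch of the updated skip condition (model names are hypothetical examples).
# str.find() returns -1 when the substring is absent, so `== -1` reads as
# "this is not a BLOOM model" -- the beam search test now runs only for BLOOM.
for model_name in [
    "bigscience/bloom-560m",
    "mistralai/Mixtral-8x7B",
    "meta-llama/Llama-2-7b-hf",
]:
    skip = model_name.lower().find("bloom") == -1
    print(f"{model_name}: skipped={skip}")
# bigscience/bloom-560m: skipped=False
# mistralai/Mixtral-8x7B: skipped=True
# meta-llama/Llama-2-7b-hf: skipped=True
```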