diff --git a/tests/test_full_model.py b/tests/test_full_model.py index 29be828..1c68051 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -142,8 +142,8 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10): @pytest.mark.skipif( - MODEL_NAME.lower().find("mixtral") > -1, - reason="Mixtral uses DynamicCache, which can change based on beam search choices", + MODEL_NAME.lower().find("bloom") == -1, + reason="Mixtral and Llama uses DynamicCache, which can change based on beam search choices", ) @pytest.mark.forked def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):