From 0f498814be57eb87529183a2dced754cd319cbe5 Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Mon, 8 Apr 2024 19:34:47 +0200 Subject: [PATCH] Fix cache again --- tests/test_chained_calls.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 07994b2..e8b492a 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -61,7 +61,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4): load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype), ] outputs_ref = [] - caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)] + cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)) + caches = [cache, cache] for i in range(inputs.shape[1]): new_caches = [] hidden_states = inputs[:, i : i + 1, :]