@ -61,7 +61,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
]
outputs_ref = []
caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)]
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
caches = [cache, cache]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]