Fix cache again

pull/570/head
Artem Chumachenko 1 month ago
parent 46e29b230e
commit 0f498814be

@@ -61,7 +61,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
]
outputs_ref = []
-caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)]
+cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
+caches = [cache, cache]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]

Loading…
Cancel
Save