Fix cache again

pull/570/head
Artem Chumachenko 2 months ago
parent 46e29b230e
commit 0f498814be

@ -61,7 +61,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype), load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
] ]
outputs_ref = [] outputs_ref = []
caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)] cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
caches = [cache, cache]
for i in range(inputs.shape[1]): for i in range(inputs.shape[1]):
new_caches = [] new_caches = []
hidden_states = inputs[:, i : i + 1, :] hidden_states = inputs[:, i : i + 1, :]

Loading…
Cancel
Save