|
|
@ -61,7 +61,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
|
|
|
|
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
|
|
|
|
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
|
|
|
|
]
|
|
|
|
]
|
|
|
|
outputs_ref = []
|
|
|
|
outputs_ref = []
|
|
|
|
caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)]
|
|
|
|
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
|
|
|
|
|
|
|
|
caches = [cache, cache]
|
|
|
|
for i in range(inputs.shape[1]):
|
|
|
|
for i in range(inputs.shape[1]):
|
|
|
|
new_caches = []
|
|
|
|
new_caches = []
|
|
|
|
hidden_states = inputs[:, i : i + 1, :]
|
|
|
|
hidden_states = inputs[:, i : i + 1, :]
|
|
|
|