diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index d4b012c..07994b2 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -10,6 +10,7 @@ import torch from petals import AutoDistributedConfig from petals.client.remote_sequential import RemoteSequential from petals.server.from_pretrained import load_pretrained_block +from petals.utils.misc import DUMMY_KEY_PAST from test_utils import * @@ -54,12 +55,13 @@ def test_chained_inference_exact_match(atol_inference=1e-4): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) outputs_inference = torch.cat(outputs_inference, dim=1) + dtype = torch.float32 ref_blocks = [ - load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32), - load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32), + load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype), + load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype), ] outputs_ref = [] - caches = [None, None] + caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)] for i in range(inputs.shape[1]): new_caches = [] hidden_states = inputs[:, i : i + 1, :]