Fix cache in tests

This commit is contained in:
Artem Chumachenko 2024-04-08 19:27:10 +02:00
parent d41ff56047
commit 46e29b230e

View File

@ -10,6 +10,7 @@ import torch
from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from petals.utils.misc import DUMMY_KEY_PAST
from test_utils import *
@ -54,12 +55,13 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
dtype = torch.float32
ref_blocks = [
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
]
outputs_ref = []
caches = [None, None]
caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]