mirror of
https://github.com/bigscience-workshop/petals
synced 2024-11-16 06:12:50 +00:00
Fix cache in tests
This commit is contained in:
parent
d41ff56047
commit
46e29b230e
@ -10,6 +10,7 @@ import torch
|
||||
from petals import AutoDistributedConfig
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.server.from_pretrained import load_pretrained_block
|
||||
from petals.utils.misc import DUMMY_KEY_PAST
|
||||
from test_utils import *
|
||||
|
||||
|
||||
@ -54,12 +55,13 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
|
||||
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
|
||||
outputs_inference = torch.cat(outputs_inference, dim=1)
|
||||
|
||||
dtype = torch.float32
|
||||
ref_blocks = [
|
||||
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
|
||||
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
|
||||
load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),
|
||||
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
|
||||
]
|
||||
outputs_ref = []
|
||||
caches = [None, None]
|
||||
caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)]
|
||||
for i in range(inputs.shape[1]):
|
||||
new_caches = []
|
||||
hidden_states = inputs[:, i : i + 1, :]
|
||||
|
Loading…
Reference in New Issue
Block a user