|
|
|
@ -10,6 +10,7 @@ import torch
|
|
|
|
|
from petals import AutoDistributedConfig
|
|
|
|
|
from petals.client.remote_sequential import RemoteSequential
|
|
|
|
|
from petals.server.from_pretrained import load_pretrained_block
|
|
|
|
|
from petals.utils.misc import DUMMY_KEY_PAST
|
|
|
|
|
from test_utils import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -54,12 +55,13 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
|
|
|
|
|
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
|
|
|
|
|
outputs_inference = torch.cat(outputs_inference, dim=1)
|
|
|
|
|
|
|
|
|
|
dtype = torch.float32
|
|
|
|
|
ref_blocks = [
|
|
|
|
|
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
|
|
|
|
|
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
|
|
|
|
|
load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),
|
|
|
|
|
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
|
|
|
|
|
]
|
|
|
|
|
outputs_ref = []
|
|
|
|
|
caches = [None, None]
|
|
|
|
|
caches = [DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)]
|
|
|
|
|
for i in range(inputs.shape[1]):
|
|
|
|
|
new_caches = []
|
|
|
|
|
hidden_states = inputs[:, i : i + 1, :]
|
|
|
|
|