|
|
|
@ -45,13 +45,16 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
|
|
|
|
|
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
|
|
|
|
|
|
|
|
|
|
embs = model.word_embeddings(test_inputs)
|
|
|
|
|
norm_embs = model.word_embeddings_layernorm(embs.float())
|
|
|
|
|
embs = model.word_embeddings_layernorm(embs.float())
|
|
|
|
|
recurrent_outputs = []
|
|
|
|
|
with model.h.inference_session() as sess:
|
|
|
|
|
for t in range(norm_embs.shape[1]):
|
|
|
|
|
recurrent_outputs.append(sess.step(norm_embs[:, t: t + 1, :]))
|
|
|
|
|
for t in range(embs.shape[1]):
|
|
|
|
|
recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
|
|
|
|
|
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
|
|
|
|
|
recurrent_outputs = model.ln_f(recurrent_outputs)
|
|
|
|
|
recurrent_outputs = (recurrent_outputs.to(embs.dtype) @ embs.t()).float()
|
|
|
|
|
|
|
|
|
|
dictionary = model.word_embeddings.weight.t()
|
|
|
|
|
recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
|
|
|
|
|
recurrent_outputs = (recurrent_outputs @ dictionary).float()
|
|
|
|
|
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
|
|
|
|
|
logger.info("Inference is consistent with forward")
|
|
|
|
|