|
|
|
@ -13,7 +13,8 @@ logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.forked
|
|
|
|
|
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
|
|
|
|
|
def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
|
|
|
|
|
@pytest.mark.parametrize("second_token_attention_mask", (1, 0))
|
|
|
|
|
def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3):
|
|
|
|
|
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
|
|
model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
|
|
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
|
|
@ -23,9 +24,11 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
|
|
|
|
|
assert len(model.transformer.h) == model.config.n_layer
|
|
|
|
|
|
|
|
|
|
test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
|
|
|
|
|
attention_mask = torch.ones_like(test_inputs)
|
|
|
|
|
attention_mask[0, 1] = second_token_attention_mask
|
|
|
|
|
|
|
|
|
|
with torch.inference_mode():
|
|
|
|
|
parallel_outputs = model.forward(test_inputs).logits
|
|
|
|
|
parallel_outputs = model.forward(test_inputs, attention_mask=attention_mask).logits
|
|
|
|
|
assert torch.all(torch.isfinite(parallel_outputs))
|
|
|
|
|
logger.info("Forward outputs are finite")
|
|
|
|
|
|
|
|
|
@ -37,7 +40,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
|
|
|
|
|
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
|
|
|
|
|
|
|
|
|
|
for t in range(embs.shape[1]):
|
|
|
|
|
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
|
|
|
|
|
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, :t+1]))
|
|
|
|
|
if t == int(embs.shape[1] // 2) and pass_empty_tensors:
|
|
|
|
|
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
|
|
|
|
|
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
|
|
|
|
@ -58,13 +61,10 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
|
|
|
|
|
ref_model.resize_token_embeddings(config.vocab_size)
|
|
|
|
|
logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
|
|
|
|
|
|
|
|
|
|
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
|
|
|
|
|
# note: this creates a dummy mask to make the test compatible with older transformer versions
|
|
|
|
|
# prior to https://github.com/huggingface/transformers/pull/17837
|
|
|
|
|
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
|
|
|
|
|
ref_outputs = ref_model.forward(test_inputs, attention_mask=attention_mask).logits.float()
|
|
|
|
|
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
|
|
|
|
|
logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
|
|
|
|
|
del ref_model, ref_outputs, dummy_mask
|
|
|
|
|
del ref_model, ref_outputs
|
|
|
|
|
else:
|
|
|
|
|
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
|
|
|
|
|
assert False
|
|
|
|
|