@@ -14,7 +14,9 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
 @pytest.mark.parametrize("second_token_attention_mask", (1, 0))
-def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(
+    pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3
+):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@@ -40,7 +42,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
 
             for t in range(embs.shape[1]):
-                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, :t+1]))
+                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, : t + 1]))
                 if t == int(embs.shape[1] // 2) and pass_empty_tensors:
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))