pull/206/head
artek0chumak 1 year ago committed by Artem Chumachenko
parent d14debea35
commit a43265a6ea

@@ -14,7 +14,9 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
 @pytest.mark.parametrize("second_token_attention_mask", (1, 0))
-def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(
+    pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3
+):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@@ -40,7 +42,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention
             recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
 
         for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, :t+1]))
+            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, : t + 1]))
            if t == int(embs.shape[1] // 2) and pass_empty_tensors:
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
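Note (not part of the diff): a minimal sketch of how the parametrized attention_mask could be built before the inference loop. The mask construction below, including masking only the second token with second_token_attention_mask, is an assumption for illustration; the per-step prefix slice attention_mask[:, : t + 1] mirrors the change above.

# Hedged sketch, not taken from the PR: one plausible way the test could build
# the mask that is later sliced as `attention_mask[:, : t + 1]` in the diff.
import torch

second_token_attention_mask = 0      # parametrized over (1, 0) in the test
seq_len = 5                          # stand-in for embs.shape[1]

# Assumption: start from an all-ones mask and set the second position to the
# parametrized value, so the test covers both masked and unmasked cases.
attention_mask = torch.ones(1, seq_len, dtype=torch.long)
attention_mask[0, 1] = second_token_attention_mask

# At step t only the prefix covering tokens 0..t is passed, matching
# `sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, : t + 1])`.
for t in range(seq_len):
    print(t, attention_mask[:, : t + 1].tolist())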
