From a43265a6ea2ab69763eaff025f605bced4413765 Mon Sep 17 00:00:00 2001
From: artek0chumak
Date: Tue, 17 Jan 2023 09:43:44 +0000
Subject: [PATCH] style

---
 tests/test_full_model.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/test_full_model.py b/tests/test_full_model.py
index 944b301..b24231e 100644
--- a/tests/test_full_model.py
+++ b/tests/test_full_model.py
@@ -14,7 +14,9 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
 @pytest.mark.parametrize("second_token_attention_mask", (1, 0))
-def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(
+    pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3
+):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@@ -40,7 +42,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
 
             for t in range(embs.shape[1]):
-                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, :t+1]))
+                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, : t + 1]))
                 if t == int(embs.shape[1] // 2) and pass_empty_tensors:
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))