From c096ebdbbceebdc17363ba9fcfc67b6534c851dc Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 29 Nov 2022 20:45:49 +0300 Subject: [PATCH] black-isort --- tests/test_remote_sequential.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index cd87f68..2c08b9a 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -79,7 +79,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32) (outputs_ref,) = block(outputs_ref) - outputs_ref = (outputs_ref - torch.cat([inputs, input_prompts_ref], dim=1)) + torch.cat([inputs, input_prompts_ref], dim=1) + outputs_ref = (outputs_ref - torch.cat([inputs, input_prompts_ref], dim=1)) + torch.cat( + [inputs, input_prompts_ref], dim=1 + ) assert torch.allclose(outputs_ref, outputs) # exact match