diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index cd87f68..2c08b9a 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -79,7 +79,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32) (outputs_ref,) = block(outputs_ref) - outputs_ref = (outputs_ref - torch.cat([inputs, input_prompts_ref], dim=1)) + torch.cat([inputs, input_prompts_ref], dim=1) + outputs_ref = (outputs_ref - torch.cat([inputs, input_prompts_ref], dim=1)) + torch.cat( + [inputs, input_prompts_ref], dim=1 + ) assert torch.allclose(outputs_ref, outputs) # exact match