Fix flapping test

pull/485/head
Aleksandr Borzunov 10 months ago
parent b3136becf2
commit 68ff865660

@ -126,6 +126,6 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
(outputs_ref * output_proj).sum().backward()
assert input_prompts_ref.grad is not None
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2)
assert intermediate_prompts_ref.grad is not None
assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)

Loading…
Cancel
Save