Increase tolerances in test_tp_block (#196)

deflapify tests
pull/199/head^2
justheuristic 1 year ago committed by GitHub
parent b4f3224cda
commit c2cb6d19ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,7 +40,7 @@ def test_tp_block(devices, custom_config):
y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
y_ours.backward(grad_proj)
assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
assert torch.allclose(y_ours, y_ref, atol=1e-6)
assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
assert torch.allclose(y_ours, y_ref, atol=1e-5)
assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)

Loading…
Cancel
Save