|
|
|
@ -40,7 +40,7 @@ def test_tp_block(devices, custom_config):
|
|
|
|
|
y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
|
|
|
|
|
y_ours.backward(grad_proj)
|
|
|
|
|
|
|
|
|
|
assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
|
|
|
|
|
assert torch.allclose(y_ours, y_ref, atol=1e-6)
|
|
|
|
|
assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
|
|
|
|
|
assert torch.allclose(y_ours, y_ref, atol=1e-5)
|
|
|
|
|
assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
|
|
|
|
|
assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)
|
|
|
|
|