|
|
|
@ -126,6 +126,6 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
|
|
|
|
|
|
|
|
|
|
(outputs_ref * output_proj).sum().backward()
|
|
|
|
|
assert input_prompts_ref.grad is not None
|
|
|
|
|
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
|
|
|
|
|
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2)
|
|
|
|
|
assert intermediate_prompts_ref.grad is not None
|
|
|
|
|
assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)
|
|
|
|
|