|
|
|
@ -57,8 +57,6 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
|
|
|
|
|
|
|
|
|
|
input_prompts = input_prompts.detach().requires_grad_(True)
|
|
|
|
|
intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
intermediate_prompts[...] = torch.randn_like(intermediate_prompts)
|
|
|
|
|
|
|
|
|
|
inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
|
|
|
|
|
assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
|
|
|
|
|