|
|
|
@ -118,11 +118,11 @@ def test_falcon(device):
|
|
|
|
|
unopt_block.load_state_dict(block.state_dict())
|
|
|
|
|
cache = unopt_cache = None
|
|
|
|
|
|
|
|
|
|
for l in range(3):
|
|
|
|
|
with torch.inference_mode():
|
|
|
|
|
dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype)
|
|
|
|
|
with torch.inference_mode():
|
|
|
|
|
for length in [10, 1, 1, 1]:
|
|
|
|
|
dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
|
|
|
|
|
block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
|
|
|
|
|
unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
|
|
|
|
|
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l
|
|
|
|
|
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l
|
|
|
|
|
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l
|
|
|
|
|
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
|
|
|
|
|
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
|
|
|
|
|
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length
|
|
|
|
|