Update test_optimized_layers

pull/500/head
Max Ryabinin 9 months ago
parent 2c27c19df4
commit 52baffb056

@@ -118,11 +118,11 @@ def test_falcon(device):
     unopt_block.load_state_dict(block.state_dict())
     cache = unopt_cache = None
-    for l in range(3):
-        with torch.inference_mode():
-            dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype)
+    with torch.inference_mode():
+        for length in [10, 1, 1, 1]:
+            dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
             block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
             unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
-            assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l
-            assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l
-            assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l
+            assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
+            assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
+            assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length
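
The updated test exercises one multi-token "prefill" pass followed by several single-token decode passes, comparing the optimized and unoptimized Falcon blocks output-by-output and cache-by-cache. The following is a minimal, self-contained sketch of that same testing pattern. It is not the Petals test: RefBlock, OptBlock, and test_block_equivalence are hypothetical stand-ins for the real Falcon block classes, and the "optimization" is only a trivial reordering so the two paths stay bit-identical.

import torch
import torch.nn as nn


class RefBlock(nn.Module):
    """Reference block: keeps a concatenated cache of everything seen so far."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, layer_past=None, use_cache=True):
        past = x if layer_past is None else torch.cat([layer_past, x], dim=1)
        # mix in the mean of the cache (a stand-in for attention over past tokens)
        out = self.proj(x) + past.mean(dim=1, keepdim=True)
        return (out, past) if use_cache else out


class OptBlock(RefBlock):
    """'Optimized' block: same math with the steps reordered; a stand-in for a real rewrite."""

    def forward(self, x, layer_past=None, use_cache=True):
        past = x if layer_past is None else torch.cat([layer_past, x], dim=1)
        context = past.mean(dim=1, keepdim=True)
        out = self.proj(x) + context
        return (out, past) if use_cache else out


def test_block_equivalence(hidden_size=32, device="cpu", dtype=torch.float32):
    torch.manual_seed(0)
    block = OptBlock(hidden_size).to(device, dtype)
    unopt_block = RefBlock(hidden_size).to(device, dtype)
    unopt_block.load_state_dict(block.state_dict())  # identical weights in both blocks

    cache = unopt_cache = None
    with torch.inference_mode():
        # one multi-token "prefill" followed by single-token decode steps,
        # mirroring the [10, 1, 1, 1] schedule in the updated test
        for length in [10, 1, 1, 1]:
            dummy_input = torch.randn(1, length, hidden_size, device=device, dtype=dtype)
            out, cache = block(dummy_input, layer_past=cache, use_cache=True)
            unopt_out, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
            assert torch.allclose(out, unopt_out, atol=1e-6, rtol=0), length
            assert torch.allclose(cache, unopt_cache, atol=1e-6, rtol=0), length


if __name__ == "__main__":
    test_block_equivalence()
    print("optimized and reference blocks agree at every step")

Carrying the cache forward between iterations is what makes the varied length schedule meaningful: the prefill step populates the cache with several positions at once, and each subsequent single-token step must extend it exactly as the reference implementation does.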
