Materialize buffers in get_block_size() (#600)

pull/601/head
Alexander Borzunov 3 months ago committed by GitHub
parent 10f7525ce0
commit 103ef760da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -32,7 +32,7 @@ def get_block_size(
dtype is not None and quant_type is not None
), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
with init_empty_weights(include_buffers=True):
with init_empty_weights(include_buffers=False):
block = get_model_block(config)
n_params = sum(param.numel() for param in block.parameters())

Loading…
Cancel
Save