Fix OOMs happening in case of accelerate >= 0.16.0 (#310)

- After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of the dtype of the loaded weights (https://github.com/huggingface/accelerate/pull/920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
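The dtype handling described above can be sketched in isolation. This is a minimal illustration, not Petals code: the state dict is made up, and `resolve_dtype` is a hypothetical helper standing in for "read `torch_dtype` from `config.json` when set to `\"auto\"`". The key point is that floating-point weights are cast to the resolved dtype while integer/bool tensors keep theirs:

```python
import torch

# Hypothetical helper: "auto" means "use the torch_dtype declared in config.json"
def resolve_dtype(torch_dtype, config_dtype):
    return config_dtype if torch_dtype == "auto" else torch_dtype

# Made-up state dict: a float32 weight plus an integer step counter
state_dict = {
    "weight": torch.zeros(4, 4, dtype=torch.float32),
    "step": torch.tensor(0, dtype=torch.int64),
}

target = resolve_dtype("auto", torch.bfloat16)  # pretend config.json says bfloat16
for name, param in state_dict.items():
    # Integer/bool tensors (counters, masks) must keep their original dtype
    if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
        param = param.to(target)
    state_dict[name] = param

print(state_dict["weight"].dtype)  # torch.bfloat16
print(state_dict["step"].dtype)    # torch.int64
```

Without this cast (and without passing `dtype=` explicitly), accelerate>=0.16.0 would keep the tensors in the model's prior dtype, so everything downstream ran in float32.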
Alexander Borzunov 1 year ago committed by GitHub
parent 93c4eba5d1
commit 454c193863

@@ -33,7 +33,7 @@ python_requires = >=3.7
 install_requires =
     torch>=1.12
     bitsandbytes==0.38.0.post2
-    accelerate>=0.15.0,<1.0.0
+    accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     transformers>=4.25.1,<5.0.0
     speedtest-cli==2.1.3

@@ -68,7 +68,7 @@ def load_pretrained_block(
         param = state_dict[param_name]
         if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
             param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param)
+        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
     logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
     return block
