From 454c193863eed5d06ccf2c33f5187c6313ffd1bb Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Tue, 25 Apr 2023 17:20:19 +0400
Subject: [PATCH] Fix OOMs happening in case of accelerate >= 0.16.0 (#310)

- After #285, `load_pretrained_block()` uses
  `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, it saves the tensor in the dtype previously used
  by the model instead of the dtype of the weights
  (https://github.com/huggingface/accelerate/pull/920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype`
  (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
---
 setup.cfg                           | 2 +-
 src/petals/bloom/from_pretrained.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 09182c6..786c8f5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -33,7 +33,7 @@ python_requires = >=3.7
 install_requires =
     torch>=1.12
     bitsandbytes==0.38.0.post2
-    accelerate>=0.15.0,<1.0.0
+    accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     transformers>=4.25.1,<5.0.0
     speedtest-cli==2.1.3

diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py
index 9f1d12b..4748b41 100644
--- a/src/petals/bloom/from_pretrained.py
+++ b/src/petals/bloom/from_pretrained.py
@@ -68,7 +68,7 @@ def load_pretrained_block(
         param = state_dict[param_name]
         if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
             param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param)
+        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

     logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
     return block
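
A minimal sketch (not part of the patch) of the dtype behavior described above; the toy `nn.Linear` and the variable names are illustrative only, assuming accelerate >= 0.16.0:

```python
import torch
from accelerate.utils import set_module_tensor_to_device

# A module whose parameters start out in float32, standing in for a loaded block.
block = torch.nn.Linear(4, 4)

# Checkpoint weights already cast to float16 (as after `param.to(torch_dtype)`).
fp16_weight = torch.randn(4, 4, dtype=torch.float16)

# Without an explicit `dtype=`, accelerate >= 0.16.0 keeps the module's previous
# dtype (float32 here), upcasting the value and roughly doubling memory use.
set_module_tensor_to_device(block, "weight", "cpu", value=fp16_weight)
print(block.weight.dtype)  # torch.float32 on accelerate >= 0.16.0

# Passing the value's own dtype, as the patch does, preserves float16.
set_module_tensor_to_device(block, "weight", "cpu", value=fp16_weight, dtype=fp16_weight.dtype)
print(block.weight.dtype)  # torch.float16
```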