Fix OOMs happening in case of accelerate >= 0.16.0 (#310)
- After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()` - In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of dtype of the weights (https://github.com/huggingface/accelerate/pull/920) - Because of that, blocks and attention caches used float32, which caused OOMs - This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)pull/311/head
parent
93c4eba5d1
commit
454c193863
Loading…
Reference in New Issue