Update convert_8bit.py

readme-clarifications
justheuristic 2 years ago committed by GitHub
parent 8f34b92b68
commit 0e5e93af7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,11 @@
import bitsandbytes as bnb
import os
import torch
PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 0)))
def replace_8bit_linear(model, threshold=6.0):
"""
A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
@ -29,6 +33,7 @@ def replace_8bit_linear(model, threshold=6.0):
module.bias is not None,
has_fp16_weights=False,
threshold=threshold,
memory_efficient_backward=PETALS_8BIT_BACKWARD,
)
model._modules[n].weight = bnb.nn.Int8Params(
module.weight.data, requires_grad=False, has_fp16_weights=False

Loading…
Cancel
Save