|
|
|
@ -1,7 +1,11 @@
|
|
|
|
|
import bitsandbytes as bnb
|
|
|
|
|
import os
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Opt-in flag read from the environment: any nonzero integer in the
# PETALS_8BIT_BACKWARD env var enables the memory-efficient backward pass
# for the 8-bit linear layers configured below (defaults to disabled).
PETALS_8BIT_BACKWARD = int(os.environ.get("PETALS_8BIT_BACKWARD", 0)) != 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_8bit_linear(model, threshold=6.0):
|
|
|
|
|
"""
|
|
|
|
|
A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
|
|
|
|
@ -29,6 +33,7 @@ def replace_8bit_linear(model, threshold=6.0):
|
|
|
|
|
module.bias is not None,
|
|
|
|
|
has_fp16_weights=False,
|
|
|
|
|
threshold=threshold,
|
|
|
|
|
memory_efficient_backward=PETALS_8BIT_BACKWARD,
|
|
|
|
|
)
|
|
|
|
|
model._modules[n].weight = bnb.nn.Int8Params(
|
|
|
|
|
module.weight.data, requires_grad=False, has_fp16_weights=False
|
|
|
|
|