|
|
|
@ -31,8 +31,6 @@ def replace_8bit_linear(model, threshold=6.0):
|
|
|
|
|
threshold=threshold,
|
|
|
|
|
)
|
|
|
|
|
model._modules[n].weight = bnb.nn.Int8Params(
|
|
|
|
|
module.weight.data,
|
|
|
|
|
requires_grad=False,
|
|
|
|
|
has_fp16_weights=False
|
|
|
|
|
module.weight.data, requires_grad=False, has_fp16_weights=False
|
|
|
|
|
).to(module.weight.dtype)
|
|
|
|
|
return model
|
|
|
|
|