fix-convert-8bit
dbaranchuk 2 years ago
parent 94f4f25e9b
commit f7555a3e3d

@ -31,8 +31,6 @@ def replace_8bit_linear(model, threshold=6.0):
threshold=threshold, threshold=threshold,
) )
model._modules[n].weight = bnb.nn.Int8Params( model._modules[n].weight = bnb.nn.Int8Params(
module.weight.data, module.weight.data, requires_grad=False, has_fp16_weights=False
requires_grad=False,
has_fp16_weights=False
).to(module.weight.dtype) ).to(module.weight.dtype)
return model return model

Loading…
Cancel
Save