Fix Linear8bitLt state config, update tests (#112)

* fix state initializer
* update tests to actually use new code
* keep bias during quantization
Branch: pull/114/head
justheuristic authored 2 years ago · committed via GitHub
parent 96033de921
commit 01838f9a99

@@ -1,5 +1,3 @@
-import os
-
 import bitsandbytes as bnb
 import torch
@@ -37,4 +35,5 @@ def replace_8bit_linear(model, threshold=6.0):
             model._modules[n].weight = bnb.nn.Int8Params(
                 module.weight.data, requires_grad=False, has_fp16_weights=False
             ).to(module.weight.dtype)
+            model._modules[n].bias = module.bias
     return model
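
The added line fixes the quantization helper silently dropping each layer's bias: previously only the weight was carried over into the swapped-in module. A minimal sketch of the patched behavior (the toy model below is illustrative, not from this repo, and is assumed to run before the model is moved to GPU):

import torch
import torch.nn as nn

# Toy fp16 model with biased linear layers (illustrative only).
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8)).half()
ref_bias = model[0].bias.detach().clone()

model = replace_8bit_linear(model, threshold=6.0)

# With the fix, the original bias parameter survives the swap.
assert torch.equal(model[0].bias.detach(), ref_bias)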

@@ -70,7 +70,12 @@ class CustomLinear8bitLt(Linear8bitLt):
     def __init__(self, *args, memory_efficient_backward: bool = False, **kwargs):
         assert not memory_efficient_backward, "memory_efficient_backward is no longer used"
         super().__init__(*args, **kwargs)
-        self.state = CustomMatmulLtState(**dataclasses.asdict(self.state))
+        old_state, self.state = self.state, CustomMatmulLtState()
+        self.state.threshold = old_state.threshold
+        self.state.has_fp16_weights = old_state.has_fp16_weights
+        self.state.memory_efficient_backward = old_state.memory_efficient_backward
+        if old_state.threshold > 0.0 and not old_state.has_fp16_weights:
+            self.state.use_pool = True

     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
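
The replaced line is the state-config fix from the commit title. One plausible reading: dataclasses.asdict(self.state) recursively deep-copies every field, which is fragile for a state object whose fields include lazily populated tensor buffers and whose field set may change across bitsandbytes versions; the new code instead builds a fresh CustomMatmulLtState and copies only the three configuration fields, re-applying the same use_pool rule that Linear8bitLt.__init__ uses. A toy illustration of the asdict behavior, using a stand-in dataclass rather than the real MatmulLtState:

import dataclasses
import torch

@dataclasses.dataclass
class ToyState:  # illustrative stand-in, not the bitsandbytes class
    threshold: float = 0.0
    CB: torch.Tensor = None  # runtime buffer, populated lazily

old = ToyState(threshold=6.0, CB=torch.zeros(2, 2))

# asdict() deep-copies every field, so stale runtime buffers leak into the
# supposedly fresh state instead of starting out as None.
copied = ToyState(**dataclasses.asdict(old))
assert copied.CB is not None

# The fixed pattern: create a clean state, then copy config fields only.
fresh = ToyState()
fresh.threshold = old.threshold
assert fresh.CB is None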

@@ -39,9 +39,10 @@ def test_linear_exact_match():
         threshold=6.0,
         memory_efficient_backward=True,
     )
-    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
+    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
         linear.weight.dtype
     )
+    linear8bitlt.bias = linear.bias
     linear8bitlt.cuda()

     linear_custom = CustomLinear8bitLt(
@@ -51,10 +52,11 @@ def test_linear_exact_match():
         has_fp16_weights=False,
         threshold=6.0,
     )
-    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
-        linear.weight.dtype
-    )
-    linear8bitlt.cuda()
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+    ).to(linear.weight.dtype)
+    linear_custom.bias = linear.bias
+    linear_custom.cuda()

     x_ref = x.clone().cuda().requires_grad_(True)
     x_ours = x.clone().cuda().requires_grad_(True)
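
The .clone() added to both weight assignments is easy to miss: Int8Params subclasses nn.Parameter around the tensor it is handed, so without the clone both test layers would wrap the very same fp16 storage. A quick sketch of the aliasing (assuming Int8Params shares storage with its input, as torch.Tensor._make_subclass does):

import torch
import bitsandbytes as bnb

w = torch.randn(8, 8, dtype=torch.half)
aliased = bnb.nn.Int8Params(w, requires_grad=False, has_fp16_weights=False)
cloned = bnb.nn.Int8Params(w.clone(), requires_grad=False, has_fp16_weights=False)

# Without .clone(), the parameter still points at the caller's buffer,
# so later in-place changes (e.g. during quantization) could leak across.
assert aliased.data_ptr() == w.data_ptr()
assert cloned.data_ptr() != w.data_ptr()
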
@@ -62,7 +64,10 @@ def test_linear_exact_match():
     grad_proj = torch.randn_like(fx_ref)
     (fx_ref * grad_proj).mean().backward()

-    fx_ours = linear8bitlt(x_ours).float()
+    fx_ours = linear_custom(x_ours).float()
     (fx_ours * grad_proj).mean().backward()
     assert torch.equal(fx_ref, fx_ours)
     assert torch.allclose(x_ref.grad, x_ours.grad)
+    assert not linear_custom.state.has_fp16_weights
+    assert linear_custom.state.CB is None
+    assert linear_custom.state.CxB is not None
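
The three new asserts pin down what the custom layer's state should look like once a forward pass has run with has_fp16_weights=False: the quantized weight is expected to live only in the tiled GPU layout (state.CxB), with the row-major int8 copy (state.CB) released. A hedged usage sketch of that expectation (sizes arbitrary, bias omitted for simplicity):

import torch

layer = CustomLinear8bitLt(64, 64, bias=False, has_fp16_weights=False, threshold=6.0).cuda().eval()
with torch.no_grad():
    layer(torch.randn(1, 64, dtype=torch.half, device="cuda"))

assert layer.state.CB is None        # row-major int8 copy freed after use
assert layer.state.CxB is not None   # tiled layout kept for later matmuls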
