@@ -39,9 +39,10 @@ def test_linear_exact_match():
         threshold=6.0,
+        memory_efficient_backward=True,
     )
-    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
+    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
         linear.weight.dtype
     )
     linear8bitlt.bias = linear.bias
     linear8bitlt.cuda()
 
     linear_custom = CustomLinear8bitLt(
@@ -51,10 +52,11 @@ def test_linear_exact_match():
         has_fp16_weights=False,
         threshold=6.0,
     )
-    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
-        linear.weight.dtype
-    )
-    linear8bitlt.cuda()
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+    ).to(linear.weight.dtype)
+    linear_custom.bias = linear.bias
+    linear_custom.cuda()
 
     x_ref = x.clone().cuda().requires_grad_(True)
     x_ours = x.clone().cuda().requires_grad_(True)
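A note on the two `.clone()` changes above: `bnb.nn.Int8Params` subclasses `torch.nn.Parameter`, which wraps the tensor it is handed rather than copying it, so building both test modules straight from `linear.weight.data` would make every layer in the test alias one storage buffer, and an in-place modification through any one of them would be visible through all the others. The CPU-only sketch below illustrates that hazard with plain `nn.Parameter` standing in for `Int8Params`; it is an illustration under that assumption, not part of the diff:

import torch
from torch import nn

ref = nn.Linear(4, 4)
aliased = nn.Parameter(ref.weight.data, requires_grad=False)  # shares storage with ref.weight
detached = nn.Parameter(ref.weight.data.clone(), requires_grad=False)  # independent copy

with torch.no_grad():
    aliased.zero_()  # in-place edit through the alias

assert torch.equal(ref.weight, aliased)  # the reference weights were clobbered too
assert not torch.equal(ref.weight, detached)  # the cloned copy is unaffected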
@@ -62,7 +64,10 @@ def test_linear_exact_match():
     grad_proj = torch.randn_like(fx_ref)
     (fx_ref * grad_proj).mean().backward()
 
-    fx_ours = linear8bitlt(x_ours).float()
+    fx_ours = linear_custom(x_ours).float()
     (fx_ours * grad_proj).mean().backward()
     assert torch.equal(fx_ref, fx_ours)
     assert torch.allclose(x_ref.grad, x_ours.grad)
+    assert not linear_custom.state.has_fp16_weights
+    assert linear_custom.state.CB is None
+    assert linear_custom.state.CxB is not None
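Two things are worth noting in the last hunk. First, the actual bug fix: `fx_ours` was previously computed with `linear8bitlt`, so the test compared the reference layer against itself and would pass regardless of what `CustomLinear8bitLt` did; it now exercises `linear_custom`. Second, the new assertions pin down the expected post-forward state with `has_fp16_weights=False`: the weights should be held only in the transformed `CxB` layout, with the row-major `CB` buffer released. The `(fx * grad_proj).mean().backward()` pattern used here is a cheap way to compare whole backward passes: backpropagating one shared random projection of the output makes each `x.grad` a random combination of Jacobian-vector products, so the two implementations only agree if their backward passes do. A self-contained sketch of that pattern, with two ordinary `nn.Linear` modules standing in for the 8-bit layers:

import torch
from torch import nn

torch.manual_seed(0)
ref = nn.Linear(8, 8)
alt = nn.Linear(8, 8)
alt.load_state_dict(ref.state_dict())  # identical weights, independent modules

# Independent leaf inputs with identical values, so each gets its own .grad
x_ref = torch.randn(3, 8, requires_grad=True)
x_alt = x_ref.detach().clone().requires_grad_(True)

grad_proj = torch.randn(3, 8)  # one shared random projection for both sides
(ref(x_ref) * grad_proj).mean().backward()
(alt(x_alt) * grad_proj).mean().backward()

assert torch.allclose(x_ref.grad, x_alt.grad)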