diff --git a/src/petals/utils/convert_8bit.py b/src/petals/utils/convert_8bit.py
index e6125ac..eeb29e7 100644
--- a/src/petals/utils/convert_8bit.py
+++ b/src/petals/utils/convert_8bit.py
@@ -1,5 +1,3 @@
-import os
-
 import bitsandbytes as bnb
 import torch
 
@@ -37,4 +35,5 @@ def replace_8bit_linear(model, threshold=6.0):
             model._modules[n].weight = bnb.nn.Int8Params(
                 module.weight.data, requires_grad=False, has_fp16_weights=False
             ).to(module.weight.dtype)
+            model._modules[n].bias = module.bias
     return model
diff --git a/src/petals/utils/linear8bitlt_patch.py b/src/petals/utils/linear8bitlt_patch.py
index 1a5064f..435bd9f 100644
--- a/src/petals/utils/linear8bitlt_patch.py
+++ b/src/petals/utils/linear8bitlt_patch.py
@@ -70,7 +70,12 @@ class CustomLinear8bitLt(Linear8bitLt):
     def __init__(self, *args, memory_efficient_backward: bool = False, **kwargs):
         assert not memory_efficient_backward, "memory_efficient_backward is no longer used"
         super().__init__(*args, **kwargs)
-        self.state = CustomMatmulLtState(**dataclasses.asdict(self.state))
+        old_state, self.state = self.state, CustomMatmulLtState()
+        self.state.threshold = old_state.threshold
+        self.state.has_fp16_weights = old_state.has_fp16_weights
+        self.state.memory_efficient_backward = old_state.memory_efficient_backward
+        if old_state.threshold > 0.0 and not old_state.has_fp16_weights:
+            self.state.use_pool = True
 
     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 9c1457e..8b4fe7f 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -39,9 +39,10 @@ def test_linear_exact_match():
         threshold=6.0,
         memory_efficient_backward=True,
     )
-    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
+    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
         linear.weight.dtype
     )
+    linear8bitlt.bias = linear.bias
     linear8bitlt.cuda()
 
     linear_custom = CustomLinear8bitLt(
@@ -51,10 +52,11 @@
         has_fp16_weights=False,
         threshold=6.0,
     )
-    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
-        linear.weight.dtype
-    )
-    linear8bitlt.cuda()
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+    ).to(linear.weight.dtype)
+    linear_custom.bias = linear.bias
+    linear_custom.cuda()
 
     x_ref = x.clone().cuda().requires_grad_(True)
     x_ours = x.clone().cuda().requires_grad_(True)
@@ -62,7 +64,10 @@
     grad_proj = torch.randn_like(fx_ref)
     (fx_ref * grad_proj).mean().backward()
 
-    fx_ours = linear8bitlt(x_ours).float()
+    fx_ours = linear_custom(x_ours).float()
     (fx_ours * grad_proj).mean().backward()
     assert torch.equal(fx_ref, fx_ours)
     assert torch.allclose(x_ref.grad, x_ours.grad)
+    assert not linear_custom.state.has_fp16_weights
+    assert linear_custom.state.CB is None
+    assert linear_custom.state.CxB is not None
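
Note for reviewers: a minimal sketch of how the bias fix can be exercised by hand. This is not part of the patch; the toy model and its layer sizes are illustrative, and it assumes the patched petals sources above are importable. With the fix, replace_8bit_linear carries each layer's original bias tensor over to the converted CustomLinear8bitLt module instead of leaving the constructor's freshly initialized bias in place.

    # Illustrative only: converts a toy model and checks that biases survive.
    import torch

    from petals.utils.convert_8bit import replace_8bit_linear
    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt

    # Arbitrary toy model; any nn.Linear layers (except lm_head/score) convert.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 3072),
        torch.nn.ReLU(),
        torch.nn.Linear(3072, 1024),
    )
    original_biases = [m.bias for m in model if isinstance(m, torch.nn.Linear)]

    model = replace_8bit_linear(model, threshold=6.0)

    converted = [m for m in model.modules() if isinstance(m, CustomLinear8bitLt)]
    assert len(converted) == 2
    # Before this patch, each converted module kept the bias created by its
    # own constructor; now it reuses the original layer's bias tensor.
    for layer, bias in zip(converted, original_biases):
        assert layer.bias is bias

The new assertions at the end of the test encode the intended memory behavior: with has_fp16_weights=False and a nonzero threshold, the custom layer is expected to retain only the transformed quantized weight (state.CxB) and drop the untransformed copy (state.CB).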