@@ -9,11 +9,12 @@ Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#
 Exact match tests: see $REPO/tests/test_linear8bitlt.py
 """
 import dataclasses
+import warnings
 from typing import Optional, Tuple

 import bitsandbytes.functional as F
 import torch
-from bitsandbytes.autograd._functions import MatMul8bitLt, MatmulLtState
+from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatMul8bitLt, MatmulLtState, prod
 from bitsandbytes.nn import Linear8bitLt
@@ -88,7 +89,7 @@ class CustomLinear8bitLt(Linear8bitLt):
         out = custom_matmul8bitlt(x, self.weight, bias=self.bias, state=self.state)
         if not self.state.has_fp16_weights:
-            if self.state.CB is not None:
+            if self.state.CB is not None and self.state.CxB is not None:
                 # we converted the 8-bit row-major weight to turing/ampere format in the first inference pass
                 # we no longer need the row-major weight
                 del self.state.CB
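+                # note: on pre-turing GPUs (or with force_no_igemmlt) state.CxB stays None, so the
+                # row-major CB is kept; it is the only weight the fallback matmul below can use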
@@ -99,6 +100,7 @@ class CustomLinear8bitLt(Linear8bitLt):
 @dataclasses.dataclass(init=True)
 class CustomMatmulLtState(MatmulLtState):
     tile_indices: Optional[torch.Tensor] = None
+    force_no_igemmlt: bool = False
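+    # force_no_igemmlt appears to be a testing/debugging switch: when True, the fp16 fallback path
+    # is taken even on GPUs whose compute capability would allow igemmlt (see forward below)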

     def get_tile_size(self):
         assert self.formatB in (
@@ -123,9 +125,155 @@ def custom_matmul8bitlt(
 class CustomMatMul8bitLt(MatMul8bitLt):
-    # forward is the same as in inference-only CxB
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, state=CustomMatmulLtState):
+        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
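+        # igemmlt is bitsandbytes' int8 matmul kernel; it requires compute capability >= 7.5 (Turing
+        # or newer), so on older GPUs we keep the weight in row-major int8 and run a dequantized
+        # fp16 matmul in step 3 instead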
+        # default to pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            ctx.bias = bias
+            if A.shape[-1] == B.shape[0]:
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Quantize A
+        # 2. Quantize B
+        # 3. Matmul
+        # 4. Mixed-precision decomposition matmul
+        # 5. Save state
+        formatB = state.formatB
+        input_shape = A.shape
+        if state.outlier_pool is None:
+            state.outlier_pool = GlobalOutlierPooler.get_instance()
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+
+        # 1. Quantize A
+        if len(A.shape) == 3:
+            A = A.view(-1, A.shape[-1]).contiguous()
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
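+        # double_quant returns A quantized per row (CA) and per column (CAt) together with the absmax
+        # scales (SCA, SCAt); coo_tensorA is a sparse tensor of the entries exceeding state.threshold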
+
+        if state.threshold > 0.0 and coo_tensorA is not None:
+            if state.has_fp16_weights:
+                idx = torch.unique(coo_tensorA.colidx).long()
+                CA[:, idx] = 0
+                CAt[:, idx] = 0
+                subA = A[:, idx]
+                state.subB = B[:, idx].t().contiguous()
+                state.idx = idx
+            else:
+                if state.CxB is None and using_igemmlt:
+                    # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
+                    # we also need to convert it to the turing/ampere format
+                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+        else:
+            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
+                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            subA = None
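+        # with fp16 weights, the outlier columns of A are zeroed in the int8 copies and kept in fp16
+        # (subA, state.subB) so step 4 can re-add their product at full precision; with int8 weights
+        # the same extraction happens below, after B has been quantized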
+
+        # 2. Quantize B
+        if state.has_fp16_weights:
+            has_grad = True if (getattr(B, "grad", None) is not None) else False
+            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
+            if is_transposed:
+                B = B.contiguous()
+
+            if (state.is_training and not has_grad) or state.CxB is None:
+                state.reset_grads()
+                (
+                    CB,
+                    state.CBt,
+                    state.SCB,
+                    state.SCBt,
+                    coo_tensorB,
+                ) = F.double_quant(B.to(torch.float16))
+                if using_igemmlt:
+                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
+                else:
+                    state.CB = CB
+        else:
+            has_grad = False
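+        # B is (re)quantized on the first pass and whenever the weights are training without an attached
+        # grad; without igemmlt the row-major codes stay in state.CB instead of the turing/ampere CxB layout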
+
+        if coo_tensorA is not None and not state.has_fp16_weights:
+            # extract outliers
+            outlier_idx = torch.unique(coo_tensorA.colidx)
+            state.idx = outlier_idx
+            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+            #     # do not use pool for 2nd FFN layer
+            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+            # else:
+            #     state.idx = outlier_idx
+            if state.CxB is not None:
+                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+            else:
+                outliers = state.CB[:, state.idx.long()].clone()
+
+            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
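+            # state.subB now holds the outlier weight columns dequantized to fp16 (int8 code * absmax / 127)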
+            CA[:, state.idx.long()] = 0
+            CAt[:, state.idx.long()] = 0
+            subA = A[:, state.idx.long()]
+
+        shapeB = state.SB[0] if state.SB else B.shape
+
+        if len(input_shape) == 3:
+            output_shape = (input_shape[0], input_shape[1], shapeB[0])
+        else:
+            output_shape = (input_shape[0], shapeB[0])
+
+        # 3. Matmul
+        if using_igemmlt:
+            C32A, SA = F.transform(CA, "col32")
+            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
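+            # int8 x int8 -> int32 GEMM in the col32 / turing-ampere layouts; mm_dequant below rescales
+            # the int32 accumulators back to floating point using the quantization scales SCA and SCB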
+            if bias is None or bias.dtype == torch.float16:
+                # we apply the fused bias here
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+                output = output.to(A.dtype)
+            else:  # apply bias separately
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+                output = output.to(A.dtype).add_(bias)
+        else:
+            A_wo_outliers = A.clone()
+            if state.idx is not None:
+                A_wo_outliers[:, state.idx.long()] = 0
+            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
+            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
+            if bias is not None:
+                output = output.add_(bias)
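+            # pre-turing fallback: the fp16 input (outlier columns zeroed) is multiplied by the int8
+            # weight codes cast to fp16, and SCB / 127 dequantizes the result per output channel;
+            # note that A itself is never quantized on this path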
+
+        # 4. Mixed-precision decomposition matmul
+        if coo_tensorA is not None and subA is not None:
+            output += torch.matmul(subA, state.subB)
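+        # the fp16 product of the outlier columns is added back here; this is the mixed-precision
+        # decomposition of LLM.int8(), and it is shared by the igemmlt and fallback paths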
+
+        # 5. Save state
+        ctx.state = state
+        ctx.formatB = formatB
+        ctx.grad_shape = input_shape
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (CAt, subA)
+            ctx.tensor_states = (SCAt, state.idx)
+        else:
+            ctx.tensors = [None, None]
+            ctx.tensor_states = (None, None)
+            ctx.save_for_backward(None, None)
+
+        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
+        return clone_func(output.view(output_shape))
+
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty: