@@ -27,7 +27,13 @@ class TransformerBackend(ModuleBackend):
    _peft_module = None

    def __init__(
        self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
        self,
        *args,
        config: PretrainedConfig,
        memory_cache: MemoryCache,
        backend_dtype: torch.dtype,
        max_chunk_size_bytes: int,
        **kwargs,
    ):
        import petals.utils.peft as _peft_module
@@ -37,6 +43,8 @@ class TransformerBackend(ModuleBackend):
        assert isinstance(self.module, TensorParallel)
        self.config = config
        self.memory_cache = memory_cache
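        # Rough per-chunk memory budget for inference: long input sequences are split into chunks
        # so that the per-chunk attention buffers stay below this limit (see _estimate_max_chunk_length)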
        self.max_chunk_size_bytes = max_chunk_size_bytes
        for name, param in self.module.named_parameters():
            assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
        for name, buf in self.module.named_buffers():
@@ -55,6 +63,7 @@ class TransformerBackend(ModuleBackend):
        )
        self.dtype = backend_dtype
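        # Size of one element of the backend dtype in bytes (e.g., 2 for float16/bfloat16),
        # used below to estimate the memory footprint of attention logits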
        self.dtype_bytes = torch.finfo(self.dtype).bits // 8
        self.shard_num_heads = []
        for shard in self.module.module_shards:
            for submodule in shard.modules():
@@ -105,14 +114,40 @@ class TransformerBackend(ModuleBackend):
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
        seq_len = hidden_states.shape[1]

        with self.memory_cache.use_cache(
            *inference_info.cache_handles
        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
            self._reorder_cache_inplace(cache_tensors, hypo_ids)

            # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
            # reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes`
            # is at least 4-6x less than `autograd_memory`.
            max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
            output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
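
            # Keys/values already stored for the first `prefix_length` tokens serve as the attention past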
            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
            for offset in range(0, seq_len, max_chunk_length):
                hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
                output_hidden_states_chunk, new_kvs = self.module.forward(
                    hidden_states_chunk, layer_past=layer_past, use_cache=True
                )
                if seq_len > max_chunk_length:
                    output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
                else:
                    output_hidden_states = output_hidden_states_chunk  # saves one memcopy
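                # new_kvs already includes keys/values for this chunk, so the next chunk
                # attends to all tokens processed so far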
                layer_past = new_kvs
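
            # Write the updated keys/values back into the shared attention cache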
            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
            return (hidden_states,)
            return (output_hidden_states,)

    def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:
        # We assume that attention logit matrices are the main thing that consumes memory, given that
        # the model uses multi-query attention
        batch_size, seq_length, hidden_size = hidden_states.shape
        worst_case_length = inference_info.prefix_length + seq_length
        attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length
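        # Illustrative numbers only: with max(shard_num_heads) = 32, batch_size = 1, fp16 (2 bytes),
        # and worst_case_length = 4096, attn_bytes_per_token = 32 * 1 * 2 * 4096 = 256 KiB,
        # so max_chunk_size_bytes = 256 MiB allows chunks of about 1024 tokens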
        return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)

    def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""