@@ -19,7 +19,7 @@ from hivemind.utils.asyncio import anext
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.server.backend import TransformerBackend
 from src.utils.misc import DUMMY, is_dummy
@@ -28,10 +28,11 @@ class TransformerConnectionHandler(ConnectionHandler):
     module_backends: Dict[ModuleUID, TransformerBackend]
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
+    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
+        self.inference_max_length = inference_max_length
 
     async def rpc_inference(
         self,
@@ -43,7 +44,15 @@ class TransformerConnectionHandler(ConnectionHandler):
         print("OPENED RPC_INFERENCE")
         request = await anext(requests)
         requested_uids = self._check_uids(request.uid)
+        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        max_length = metadata.get("max_length")
+
+        if not requested_uids:
+            raise ValueError("User must specify at least one block for inference, but got none")
+        assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+        if not 0 <= max_length <= self.inference_max_length:
+            raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
 
         batch_size = request.tensors[0].size[0] if request.tensors else 1
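
Note: with this change the client must send `max_length` via msgpack-encoded request metadata before the first inference step. A minimal client-side sketch, assuming a hypothetical module UID "bloom.petals.0" and a pre-serialized tensor list `serialized_tensors` (the session/stream plumbing is elided):

    from hivemind.proto import runtime_pb2
    from hivemind.utils.serializer import MSGPackSerializer

    # Declare up front how many tokens this session may ever need, so the
    # server can reserve the whole attention cache before the first step.
    request = runtime_pb2.ExpertRequest(
        uid="bloom.petals.0",  # hypothetical module UID
        tensors=serialized_tensors,
        metadata=MSGPackSerializer.dumps({"max_length": 2048}),
    )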
@@ -52,10 +61,17 @@ class TransformerConnectionHandler(ConnectionHandler):
         )  # [cache_handle, prefix_length]
         prefix_length = 0
 
-        async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
+        async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
             assert len(cache_handles) == len(requested_backends)
             while request.tensors:  # iterate while user is willing to supply tensors
                 hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+                if prefix_length + length_increment > max_length:
+                    raise ValueError(
+                        f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                        f" exceeds pre-allocated maximum {max_length}"
+                    )
 
                 # Cast inputs to backend dtype
                 hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
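
The check above runs once per streamed step, so the pre-allocated budget is enforced cumulatively across the session. A toy illustration of the accounting (pure Python, assuming max_length = 8):

    max_length, prefix_length = 8, 0
    for length_increment in (5, 2, 2):  # tokens supplied at each step
        if prefix_length + length_increment > max_length:
            raise ValueError(f"prefix {prefix_length} + current {length_increment} exceeds {max_length}")
        prefix_length += length_increment  # reaches 5, then 7; the third step (7 + 2 > 8) raises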
@@ -113,7 +129,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
         hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
 
         # Serialize the overall output
         serialized_output = [
@@ -193,7 +209,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
+    async def _allocate_caches(
+        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+    ) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
@@ -202,7 +220,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
+                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
                 )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
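
Since `max_length` now sizes the cache directly, each block reserves 2 * batch_size * max_length * num_heads * head_dim elements (one plane each for keys and values). A quick back-of-the-envelope sketch with illustrative numbers (not taken from this PR):

    import torch

    batch_size, max_length, num_heads, head_dim = 1, 2048, 32, 128
    dtype = torch.float16
    num_elements = 2 * batch_size * max_length * num_heads * head_dim  # key + value planes
    cache_bytes = num_elements * torch.finfo(dtype).bits // 8
    print(f"{cache_bytes / 2**20:.0f} MiB per block")  # 32 MiB at these settings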