@@ -1,8 +1,16 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Union

 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
+from hivemind import (
+    DHT,
+    MSGPackSerializer,
+    P2PContext,
+    TensorDescriptor,
+    deserialize_torch_tensor,
+    nested_flatten,
+    serialize_torch_tensor,
+)
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -12,6 +20,7 @@ from hivemind.utils.streaming import split_for_streaming

 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.utils.misc import DUMMY, is_dummy


 class TransformerConnectionHandler(ConnectionHandler):
@@ -33,7 +42,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         try:
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
-            requested_uids = self._check_header(request)
+            requested_uids = self._check_uids(request.uid)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

             batch_size = request.tensors[0].size[0] if request.tensors else 1
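+            # note: size[0] of the first serialized tensor is its batch dimension; an empty opening request falls back to batch_size = 1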
@@ -80,27 +89,18 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
-        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
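+        # hidden_states now holds the activations after the last requested block, shaped [batch_size, seq_length, hid_size]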
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-
-        # Serialize the overall output and respond
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        # Serialize output and respond to client
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
@@ -108,29 +108,20 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         # Parse requests and prepare backends
-        uids_header, hidden_states = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Cast inputs to backend dtype
-        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
-        # Run a chain of requested backends
-        for backend in requested_backends:
-            assert isinstance(hidden_states, (list, tuple))
-            assert (
-                len(hidden_states) == 1 and hidden_states[0].ndim == 3
-            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

         # Serialize the overall output
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]

-        # Split the serialized_output for streaming and respond
+        # Split the serialized_output for streaming and respond to client
         output_split = [
             part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
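+        # each part is at most DEFAULT_MAX_MSG_SIZE bytes, so every chunk fits the p2p daemon's message size limit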
@@ -139,36 +130,25 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
-        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a chain of requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
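+        # grads is [grad_inputs] alone, or [grad_inputs, grad_prompts] when trainable prompts were supplied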
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize
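+        # duplicating args_schema lets grad_prompts be serialized with the same tensor descriptor as grad_inputs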
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )
@@ -176,36 +156,23 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
-        inputs, grads = inputs_and_grads
-        requested_uids = self._check_header_str(uids_header)
+        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

-        # Cast inputs & grad outputs to backend dtype
-        inputs = inputs.to(requested_backends[0].dtype)
-        grads = grads.to(requested_backends[-1].dtype)
-
-        # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs
-        inter_inputs = [inputs]
-        for backend in requested_backends[:-1]:
-            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-            inputs = await backend.forward_pool.submit_task(inputs)
-            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-            inputs = inputs[0]
-            inter_inputs.append(inputs)
-
-        # Run a backward chain for requested backends
-        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
-            inputs_and_grads = [inp, grads]
-            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert isinstance(grads, (list, tuple)) and len(grads) == 1
-            grads = grads[0]
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )  # TODO generalize

         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]

         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -215,19 +182,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])
-    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    def _check_header_str(self, header) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (header or "").split(CHAIN_DELIMITER)
+    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+        """Check that the given uids are valid and served by this handler"""
+        uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
@@ -252,3 +209,83 @@ class TransformerConnectionHandler(ConnectionHandler):
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
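+                # each handle refers to a memory-cache slot reserved for one backend for the lifetime of the session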
             yield handles

+async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, *prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # Check input shapes and cast hidden states to the backend dtype
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
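+    # prompts[i] now corresponds to requested_backends[i]: either DUMMY or a [batch_size or 1, pre_seq_len, hid_size] tensor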
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"outputs of {type(backend)} must be a single 3d tensor of hidden states"
+    # Return the hidden states after the last layer; serialization happens in the caller
+    return hidden_states

+async def _rpc_backward(
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
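+    """
+    Run a forward chain to recompute intermediate activations, then backpropagate through requested_backends in reverse.
+
+    :param flat_tensors: inputs to the first block, gradients w.r.t. the last block's outputs, and optional prompts
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in the forward pass
+    :returns: [grad_outputs] alone, or [grad_outputs, grad_prompts] when real (non-dummy) prompts were provided
+    """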
+    inputs, grad_outputs, *prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+    if not prompts or is_dummy(prompts[0]):
+        prompts = [DUMMY] * len(requested_backends)
+        pre_seq_len = 0
+    else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
+        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
+        pre_seq_len = prompts[0].shape[-2]
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        inter_inputs.append(inputs)
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, :pre_seq_len] += prompts[-1]
+    inter_inputs.append(inputs)
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a backward chain through the requested backends in reverse order
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
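+        # when this backend had a real prompt, its gradient is the slice of grad_outputs at the prompt positions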
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape