@ -68,6 +68,7 @@ class TransformerConnectionHandler(ConnectionHandler):
dht : DHT ,
module_backends : Dict [ str , TransformerBackend ] ,
* ,
adapters : Optional [ Sequence [ str ] ] ,
dht_prefix : str ,
push_manager : multiprocessing . managers . SyncManager ,
session_queues : Dict [ str , multiprocessing . managers . BaseProxy ] , # BaseProxy for queue.Queue
@ -81,6 +82,7 @@ class TransformerConnectionHandler(ConnectionHandler):
for module_backend in self . module_backends . values ( ) :
assert isinstance ( module_backend , TransformerBackend )
self . dht_prefix = dht_prefix
self . adapters = adapters
self . _push_manager = push_manager
self . _session_queues = session_queues
self . _executor = ThreadPoolExecutor ( max_workers = float ( " inf " ) ) # For waiting on self.session_queues
@ -141,7 +143,7 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer . loads ( request . metadata ) if request . metadata else { }
requested_backends = tuple ( self . module_backends [ uid ] for uid in requested_uids )
max_length = metadata . get ( " max_length " )
active_adapter = metadata . get ( " active_adapter " , " " )
active_adapter = self . _get_active_adapter ( metadata )
points = metadata . get ( " points " , 0 )
session_id = metadata . get ( " session_id " )
@ -355,7 +357,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple ( self . module_backends [ uid ] for uid in requested_uids )
metadata = MSGPackSerializer . loads ( request . metadata ) if request . metadata else { }
active_adapter = metadata . get ( " active_adapter " , " " )
active_adapter = self . _get_active_adapter ( metadata )
points = metadata . get ( " points " , 0 )
assert isinstance (
points , ( float , int )
@ -382,7 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
self . _log_request ( " rpc_forward_stream " , requested_uids , context )
requested_backends = tuple ( self . module_backends [ uid ] for uid in requested_uids )
active_adapter = metadata . get ( " active_adapter " , " " )
active_adapter = self . _get_active_adapter ( metadata )
points = metadata . get ( " points " , 0 )
assert isinstance (
points , ( float , int )
@ -433,7 +435,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple ( self . module_backends [ uid ] for uid in requested_uids )
metadata = MSGPackSerializer . loads ( request . metadata ) if request . metadata else { }
active_adapter = metadata . get ( " active_adapter " , " " )
active_adapter = self . _get_active_adapter ( metadata )
points = metadata . get ( " points " , 0 )
assert isinstance (
points , ( float , int )
@ -458,7 +460,7 @@ class TransformerConnectionHandler(ConnectionHandler):
self . _log_request ( " rpc_backward_stream " , requested_uids , context )
requested_backends = tuple ( self . module_backends [ uid ] for uid in requested_uids )
active_adapter = metadata . get ( " active_adapter " , " " )
active_adapter = self . _get_active_adapter ( metadata )
points = metadata . get ( " points " , 0 )
assert isinstance (
points , ( float , int )
@ -476,6 +478,12 @@ class TransformerConnectionHandler(ConnectionHandler):
for part in split_for_streaming ( tensor , DEFAULT_MAX_MSG_SIZE ) :
yield runtime_pb2 . ExpertResponse ( tensors = [ part ] )
def _get_active_adapter(self, metadata: dict) -> str:
    """Return the adapter name requested in ``metadata`` after checking it against ``self.adapters``.

    :param metadata: deserialized request metadata; the adapter name is read from its
        ``"active_adapter"`` key
    :returns: the requested adapter name, or an empty string if the request names no adapter
    :raises KeyError: if a non-empty adapter name is requested that is not listed in ``self.adapters``
    """
    active_adapter = metadata.get("active_adapter", "")
    # ``adapters`` is accepted as Optional[...] by the constructor and stored unchanged, so it
    # may be None here; treat that as "no adapters available" so the membership test below
    # raises the intended KeyError rather than a TypeError
    available_adapters = self.adapters if self.adapters is not None else ()
    if active_adapter and active_adapter not in available_adapters:
        raise KeyError(f"adapter {active_adapter} not found")
    return active_adapter
def _serialize_grads (
self ,
grads : Sequence [ torch . Tensor ] ,