@ -12,8 +12,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind . p2p import PeerID
from hivemind . utils import DHTExpiration , MPFuture , get_dht_time , get_logger , use_hivemind_log_handler
import src
from src . data_structures import CHAIN_DELIMITER , UID_DELIMITER , ModuleUID , RemoteModuleInfo , ServerInfo , ServerState
import petals. client
from petal s. data_structures import CHAIN_DELIMITER , UID_DELIMITER , ModuleUID , RemoteModuleInfo , ServerInfo , ServerState
use_hivemind_log_handler ( " in_root_logger " )
logger = get_logger ( __file__ )
@ -76,10 +76,10 @@ def get_remote_sequence(
dht : DHT ,
start : int ,
stop : int ,
config : src . DistributedBloomConfig ,
config : petals. client . DistributedBloomConfig ,
dht_prefix : Optional [ str ] = None ,
return_future : bool = False ,
) - > Union [ src . RemoteSequential , MPFuture ] :
) - > Union [ petals. client . RemoteSequential , MPFuture ] :
return RemoteExpertWorker . run_coroutine (
_get_remote_sequence ( dht , start , stop , config , dht_prefix ) , return_future = return_future
)
@ -89,22 +89,22 @@ async def _get_remote_sequence(
dht : DHT ,
start : int ,
stop : int ,
config : src . DistributedBloomConfig ,
config : petals. client . DistributedBloomConfig ,
dht_prefix : Optional [ str ] = None ,
) - > src . RemoteSequential :
) - > petals. client . RemoteSequential :
uids = [ f " { config . dht_prefix } { UID_DELIMITER } { i } " for i in range ( start , stop ) ]
p2p = await dht . replicate_p2p ( )
manager = src . RemoteSequenceManager ( dht , uids , p2p )
return src . RemoteSequential ( config , dht , dht_prefix , p2p , manager )
manager = petals. client . RemoteSequenceManager ( dht , uids , p2p )
return petals. client . RemoteSequential ( config , dht , dht_prefix , p2p , manager )
def get_remote_module (
dht : DHT ,
uid_or_uids : Union [ ModuleUID , List [ ModuleUID ] ] ,
config : src . DistributedBloomConfig ,
config : petals. client . DistributedBloomConfig ,
dht_prefix : Optional [ str ] = None ,
return_future : bool = False ,
) - > Union [ Union [ src. RemoteTransformerBlock , List [ src . RemoteTransformerBlock ] ] , MPFuture ] :
) - > Union [ Union [ petals. client . RemoteTransformerBlock , List [ petals . client . RemoteTransformerBlock ] ] , MPFuture ] :
"""
: param uid_or_uids : find one or more modules with these ids from across the DHT
: param config : model config , usualy taken by . from_pretrained ( MODEL_NAME )
@ -119,15 +119,16 @@ def get_remote_module(
async def _get_remote_module (
dht : DHT ,
uid_or_uids : Union [ ModuleUID , List [ ModuleUID ] ] ,
config : src . DistributedBloomConfig ,
config : petals. client . DistributedBloomConfig ,
dht_prefix : Optional [ str ] = None ,
) - > Union [ src. RemoteTransformerBlock , List [ src . RemoteTransformerBlock ] ] :
) - > Union [ petals. client . RemoteTransformerBlock , List [ petals . client . RemoteTransformerBlock ] ] :
single_uid = isinstance ( uid_or_uids , ModuleUID )
uids = [ uid_or_uids ] if single_uid else uid_or_uids
p2p = await dht . replicate_p2p ( )
managers = ( src . RemoteSequenceManager ( dht , [ uid ] , p2p ) for uid in uids )
managers = ( petals. client . RemoteSequenceManager ( dht , [ uid ] , p2p ) for uid in uids )
modules = [
src . RemoteTransformerBlock ( config , dht , dht_prefix = dht_prefix , p2p = p2p , sequence_manager = m ) for m in managers
petals . client . RemoteTransformerBlock ( config , dht , dht_prefix = dht_prefix , p2p = p2p , sequence_manager = m )
for m in managers
]
return modules [ 0 ] if single_uid else modules