|
|
|
@ -3,14 +3,17 @@ from functools import partial
|
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind import DHT
|
|
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
|
from src import DistributedBloomConfig
|
|
|
|
|
from src.client.remote_model import logger
|
|
|
|
|
from src.data_structures import UID_DELIMITER, RemoteModuleInfo
|
|
|
|
|
from src.dht_utils import _get_remote_module_infos, _create_remote_modules_from_infos
|
|
|
|
|
from hivemind import DHT, use_hivemind_log_handler, get_logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Select hivemind's log-handler mode at import time, before the module logger is created.
# NOTE(review): "in_root_logger" presumably installs the hivemind handler on the root
# logger, which would affect logging for the whole process — confirm against hivemind docs.
use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
# Module-level logger obtained from hivemind's get_logger, keyed by this file's path.
# NOTE(review): this rebinds the name `logger`, which also appears in the
# `from src.client.remote_model import logger` import above; if that import is still
# present in the file it is shadowed here and looks redundant — confirm.
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RemoteSequential(nn.Sequential):
|
|
|
|
|