mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
Fix imports
This commit is contained in:
parent
625d042d0a
commit
59db85174e
@ -140,11 +140,10 @@ Once your have enough servers, you can use them to train and/or inference the mo
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from petals.ansformers
|
||||
from src import DistributedBloomForCausalLM
|
||||
from petals import BloomTokenizerFast, DistributedBloomForCausalLM
|
||||
|
||||
initial_peers = [TODO_put_one_or_more_server_addresses_here] # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
|
||||
tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
|
||||
tokenizer = BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
|
||||
model = DistributedBloomForCausalLM.from_pretrained(
|
||||
"bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
||||
) # this model has only embeddings / logits, all transformer blocks rely on remote servers
|
||||
|
@ -9,7 +9,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
from huggingface_hub import Repository
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from petals.import BloomModel
|
||||
from petals import BloomModel
|
||||
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
|
||||
from petals.client import DistributedBloomConfig
|
||||
|
||||
|
@ -7,7 +7,7 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
|
||||
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from torch import nn
|
||||
|
||||
import src
|
||||
import petals
|
||||
from petals.client.inference_session import InferenceSession
|
||||
from petals.client.sequence_manager import RemoteSequenceManager
|
||||
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
|
||||
@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht: DHT,
|
||||
dht_prefix: Optional[str] = None,
|
||||
p2p: Optional[P2P] = None,
|
||||
|
@ -12,7 +12,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from hivemind.p2p import PeerID
|
||||
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
|
||||
|
||||
import src
|
||||
import petals
|
||||
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
@ -76,10 +76,10 @@ def get_remote_sequence(
|
||||
dht: DHT,
|
||||
start: int,
|
||||
stop: int,
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht_prefix: Optional[str] = None,
|
||||
return_future: bool = False,
|
||||
) -> Union[src.RemoteSequential, MPFuture]:
|
||||
) -> Union[petals.RemoteSequential, MPFuture]:
|
||||
return RemoteExpertWorker.run_coroutine(
|
||||
_get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
|
||||
)
|
||||
@ -89,22 +89,22 @@ async def _get_remote_sequence(
|
||||
dht: DHT,
|
||||
start: int,
|
||||
stop: int,
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht_prefix: Optional[str] = None,
|
||||
) -> src.RemoteSequential:
|
||||
) -> petals.RemoteSequential:
|
||||
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
|
||||
p2p = await dht.replicate_p2p()
|
||||
manager = src.RemoteSequenceManager(dht, uids, p2p)
|
||||
return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
|
||||
manager = petals.RemoteSequenceManager(dht, uids, p2p)
|
||||
return petals.RemoteSequential(config, dht, dht_prefix, p2p, manager)
|
||||
|
||||
|
||||
def get_remote_module(
|
||||
dht: DHT,
|
||||
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht_prefix: Optional[str] = None,
|
||||
return_future: bool = False,
|
||||
) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
|
||||
) -> Union[Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]], MPFuture]:
|
||||
"""
|
||||
:param uid_or_uids: find one or more modules with these ids from across the DHT
|
||||
:param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
|
||||
@ -119,15 +119,15 @@ def get_remote_module(
|
||||
async def _get_remote_module(
|
||||
dht: DHT,
|
||||
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht_prefix: Optional[str] = None,
|
||||
) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
|
||||
) -> Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]]:
|
||||
single_uid = isinstance(uid_or_uids, ModuleUID)
|
||||
uids = [uid_or_uids] if single_uid else uid_or_uids
|
||||
p2p = await dht.replicate_p2p()
|
||||
managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
|
||||
managers = (petals.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
|
||||
modules = [
|
||||
src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
|
||||
petals.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
|
||||
]
|
||||
return modules[0] if single_uid else modules
|
||||
|
||||
|
@ -16,7 +16,7 @@ from hivemind.moe.server.runtime import Runtime
|
||||
from hivemind.proto.runtime_pb2 import CompressionType
|
||||
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
|
||||
from petals.import BloomConfig, declare_active_modules
|
||||
from petals import BloomConfig, declare_active_modules
|
||||
from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
|
||||
from petals.constants import PUBLIC_INITIAL_PEERS
|
||||
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
|
||||
|
@ -11,7 +11,7 @@ from typing import Dict, Union
|
||||
import torch
|
||||
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
|
||||
from petals.import project_name
|
||||
from petals import project_name
|
||||
from petals.bloom.block import BloomBlock
|
||||
from petals.bloom.model import BloomConfig
|
||||
from petals.bloom.ops import build_alibi_tensor
|
||||
|
@ -1,6 +1,6 @@
|
||||
from src.bloom import *
|
||||
from src.client import *
|
||||
from src.dht_utils import declare_active_modules, get_remote_module
|
||||
from petals.bloom import *
|
||||
from petals.client import *
|
||||
from petals.dht_utils import declare_active_modules, get_remote_module
|
||||
|
||||
project_name = "bloomd"
|
||||
__version__ = "0.2"
|
||||
|
@ -1,2 +1,2 @@
|
||||
from src.bloom.block import BloomBlock
|
||||
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
|
||||
from petals.bloom.block import BloomBlock
|
||||
from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
|
||||
|
@ -9,7 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.quantized.dynamic.modules.linear
|
||||
|
||||
from src.bloom.ops import (
|
||||
from petals.bloom.ops import (
|
||||
BloomGelu,
|
||||
BloomScaledSoftmax,
|
||||
attention_mask_func,
|
||||
|
@ -15,7 +15,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
from transformers.modeling_utils import WEIGHTS_NAME
|
||||
from transformers.utils.hub import cached_path, hf_bucket_url
|
||||
|
||||
from src.bloom import BloomBlock, BloomConfig
|
||||
from petals.bloom import BloomBlock, BloomConfig
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
|
@ -26,7 +26,7 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
|
||||
from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
|
||||
from transformers.utils import logging
|
||||
|
||||
from src.bloom.block import BloomBlock
|
||||
from petals.bloom.block import BloomBlock
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = logging.get_logger(__file__)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from src.client.inference_session import InferenceSession
|
||||
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
|
||||
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
|
||||
from src.client.sequence_manager import RemoteSequenceManager
|
||||
from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
||||
from petals.client.inference_session import InferenceSession
|
||||
from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
|
||||
from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
|
||||
from petals.client.sequence_manager import RemoteSequenceManager
|
||||
from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
||||
|
@ -20,10 +20,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from hivemind.p2p import StubBase
|
||||
from hivemind.proto import runtime_pb2
|
||||
|
||||
from src.client.sequence_manager import RemoteSequenceManager
|
||||
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
|
||||
from src.server.handler import TransformerConnectionHandler
|
||||
from src.utils.misc import DUMMY, is_dummy
|
||||
from petals.client.sequence_manager import RemoteSequenceManager
|
||||
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
|
||||
from petals.server.handler import TransformerConnectionHandler
|
||||
from petals.utils.misc import DUMMY, is_dummy
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
|
||||
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
|
||||
from hivemind.utils.streaming import split_for_streaming
|
||||
|
||||
from src.data_structures import ModuleUID, RPCInfo
|
||||
from petals.data_structures import ModuleUID, RPCInfo
|
||||
|
||||
|
||||
async def _forward_unary(
|
||||
|
@ -3,14 +3,14 @@ from typing import List, Optional
|
||||
import torch
|
||||
from hivemind.utils.logging import get_logger
|
||||
|
||||
from src.utils.generation_algorithms import (
|
||||
from petals.utils.generation_algorithms import (
|
||||
BeamSearchAlgorithm,
|
||||
DecodingAlgorithm,
|
||||
GreedyAlgorithm,
|
||||
NucleusAlgorithm,
|
||||
TopKAlgorithm,
|
||||
)
|
||||
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
|
||||
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
from hivemind import get_logger, use_hivemind_log_handler
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
|
||||
from src.bloom.model import (
|
||||
from petals.bloom.model import (
|
||||
BloomConfig,
|
||||
BloomForCausalLM,
|
||||
BloomForSequenceClassification,
|
||||
@ -15,10 +15,10 @@ from src.bloom.model import (
|
||||
BloomPreTrainedModel,
|
||||
LMHead,
|
||||
)
|
||||
from src.client.remote_generation import RemoteGenerationMixin
|
||||
from src.client.remote_sequential import RemoteSequential
|
||||
from src.constants import PUBLIC_INITIAL_PEERS
|
||||
from src.utils.misc import DUMMY
|
||||
from petals.client.remote_generation import RemoteGenerationMixin
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.constants import PUBLIC_INITIAL_PEERS
|
||||
from petals.utils.misc import DUMMY
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
|
@ -7,12 +7,12 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
|
||||
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from torch import nn
|
||||
|
||||
import src
|
||||
from src.client.inference_session import InferenceSession
|
||||
from src.client.sequence_manager import RemoteSequenceManager
|
||||
from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
|
||||
from src.data_structures import UID_DELIMITER
|
||||
from src.utils.misc import DUMMY
|
||||
import petals
|
||||
from petals.client.inference_session import InferenceSession
|
||||
from petals.client.sequence_manager import RemoteSequenceManager
|
||||
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
|
||||
from petals.data_structures import UID_DELIMITER
|
||||
from petals.utils.misc import DUMMY
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht: DHT,
|
||||
dht_prefix: Optional[str] = None,
|
||||
p2p: Optional[P2P] = None,
|
||||
|
@ -9,10 +9,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from hivemind.proto import runtime_pb2
|
||||
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
|
||||
from src.client.spending_policy import NoSpendingPolicy
|
||||
from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
|
||||
from src.dht_utils import get_remote_module_infos
|
||||
from src.server.handler import TransformerConnectionHandler
|
||||
from petals.client.spending_policy import NoSpendingPolicy
|
||||
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
|
||||
from petals.dht_utils import get_remote_module_infos
|
||||
from petals.server.handler import TransformerConnectionHandler
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
|
@ -11,11 +11,11 @@ import torch
|
||||
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from hivemind.utils.logging import get_logger
|
||||
|
||||
from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
|
||||
from src.client.sequence_manager import RemoteSequenceManager
|
||||
from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
|
||||
from src.server.handler import TransformerConnectionHandler
|
||||
from src.utils.misc import DUMMY, is_dummy
|
||||
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
|
||||
from petals.client.sequence_manager import RemoteSequenceManager
|
||||
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
|
||||
from petals.server.handler import TransformerConnectionHandler
|
||||
from petals.utils.misc import DUMMY, is_dummy
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
@ -12,8 +12,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||
from hivemind.p2p import PeerID
|
||||
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
|
||||
|
||||
import src
|
||||
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
|
||||
import petals
|
||||
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
@ -76,10 +76,10 @@ def get_remote_sequence(
|
||||
dht: DHT,
|
||||
start: int,
|
||||
stop: int,
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht_prefix: Optional[str] = None,
|
||||
return_future: bool = False,
|
||||
) -> Union[src.RemoteSequential, MPFuture]:
|
||||
) -> Union[petals.RemoteSequential, MPFuture]:
|
||||
return RemoteExpertWorker.run_coroutine(
|
||||
_get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
|
||||
)
|
||||
@ -89,22 +89,22 @@ async def _get_remote_sequence(
|
||||
dht: DHT,
|
||||
start: int,
|
||||
stop: int,
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht_prefix: Optional[str] = None,
|
||||
) -> src.RemoteSequential:
|
||||
) -> petals.RemoteSequential:
|
||||
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
|
||||
p2p = await dht.replicate_p2p()
|
||||
manager = src.RemoteSequenceManager(dht, uids, p2p)
|
||||
return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
|
||||
manager = petals.RemoteSequenceManager(dht, uids, p2p)
|
||||
return petals.RemoteSequential(config, dht, dht_prefix, p2p, manager)
|
||||
|
||||
|
||||
def get_remote_module(
|
||||
dht: DHT,
|
||||
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht_prefix: Optional[str] = None,
|
||||
return_future: bool = False,
|
||||
) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
|
||||
) -> Union[Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]], MPFuture]:
|
||||
"""
|
||||
:param uid_or_uids: find one or more modules with these ids from across the DHT
|
||||
:param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
|
||||
@ -119,15 +119,15 @@ def get_remote_module(
|
||||
async def _get_remote_module(
|
||||
dht: DHT,
|
||||
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
||||
config: src.DistributedBloomConfig,
|
||||
config: petals.DistributedBloomConfig,
|
||||
dht_prefix: Optional[str] = None,
|
||||
) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
|
||||
) -> Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]]:
|
||||
single_uid = isinstance(uid_or_uids, ModuleUID)
|
||||
uids = [uid_or_uids] if single_uid else uid_or_uids
|
||||
p2p = await dht.replicate_p2p()
|
||||
managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
|
||||
managers = (petals.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
|
||||
modules = [
|
||||
src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
|
||||
petals.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
|
||||
]
|
||||
return modules[0] if single_uid else modules
|
||||
|
||||
|
@ -6,10 +6,10 @@ from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
|
||||
from hivemind.moe.server.module_backend import ModuleBackend
|
||||
from hivemind.utils import get_logger
|
||||
|
||||
from src.bloom.from_pretrained import BloomBlock
|
||||
from src.server.cache import MemoryCache
|
||||
from src.server.task_pool import PrioritizedTaskPool
|
||||
from src.utils.misc import is_dummy
|
||||
from petals.bloom.from_pretrained import BloomBlock
|
||||
from petals.server.cache import MemoryCache
|
||||
from petals.server.task_pool import PrioritizedTaskPool
|
||||
from petals.utils.misc import is_dummy
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
|
@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
from hivemind import PeerID, get_logger
|
||||
|
||||
from src.data_structures import RemoteModuleInfo, ServerState
|
||||
from petals.data_structures import RemoteModuleInfo, ServerState
|
||||
|
||||
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
|
||||
|
||||
|
@ -21,11 +21,11 @@ from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
|
||||
from hivemind.utils.logging import get_logger
|
||||
from hivemind.utils.streaming import split_for_streaming
|
||||
|
||||
from src.data_structures import CHAIN_DELIMITER, ModuleUID
|
||||
from src.server.backend import TransformerBackend
|
||||
from src.server.task_pool import PrioritizedTaskPool
|
||||
from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
|
||||
from src.utils.misc import DUMMY, is_dummy
|
||||
from petals.data_structures import CHAIN_DELIMITER, ModuleUID
|
||||
from petals.server.backend import TransformerBackend
|
||||
from petals.server.task_pool import PrioritizedTaskPool
|
||||
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
|
||||
from petals.utils.misc import DUMMY, is_dummy
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
@ -16,17 +16,17 @@ from hivemind.moe.server.runtime import Runtime
|
||||
from hivemind.proto.runtime_pb2 import CompressionType
|
||||
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
|
||||
from src import BloomConfig, declare_active_modules
|
||||
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
|
||||
from src.constants import PUBLIC_INITIAL_PEERS
|
||||
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
|
||||
from src.dht_utils import get_remote_module_infos
|
||||
from src.server import block_selection
|
||||
from src.server.backend import TransformerBackend
|
||||
from src.server.cache import MemoryCache
|
||||
from src.server.handler import TransformerConnectionHandler
|
||||
from src.server.throughput import get_host_throughput
|
||||
from src.utils.convert_8bit import replace_8bit_linear
|
||||
from petals import BloomConfig, declare_active_modules
|
||||
from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
|
||||
from petals.constants import PUBLIC_INITIAL_PEERS
|
||||
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
|
||||
from petals.dht_utils import get_remote_module_infos
|
||||
from petals.server import block_selection
|
||||
from petals.server.backend import TransformerBackend
|
||||
from petals.server.cache import MemoryCache
|
||||
from petals.server.handler import TransformerConnectionHandler
|
||||
from petals.server.throughput import get_host_throughput
|
||||
from petals.utils.convert_8bit import replace_8bit_linear
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
|
@ -11,10 +11,10 @@ from typing import Dict, Union
|
||||
import torch
|
||||
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
|
||||
from src import project_name
|
||||
from src.bloom.block import BloomBlock
|
||||
from src.bloom.model import BloomConfig
|
||||
from src.bloom.ops import build_alibi_tensor
|
||||
from petals import project_name
|
||||
from petals.bloom.block import BloomBlock
|
||||
from petals.bloom.model import BloomConfig
|
||||
from petals.bloom.ops import build_alibi_tensor
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
|
@ -7,8 +7,8 @@ import transformers
|
||||
from hivemind import P2PHandlerError
|
||||
from test_utils import *
|
||||
|
||||
import src
|
||||
from petals.import DistributedBloomConfig
|
||||
import petals
|
||||
from petals import DistributedBloomConfig
|
||||
from petals.bloom.from_pretrained import load_pretrained_block
|
||||
from petals.client.remote_sequential import RemoteTransformerBlock
|
||||
from petals.data_structures import UID_DELIMITER
|
||||
|
@ -9,7 +9,7 @@ import pytest
|
||||
import torch
|
||||
from test_utils import *
|
||||
|
||||
import src
|
||||
import petals
|
||||
from petals.bloom.from_pretrained import load_pretrained_block
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.dht_utils import get_remote_sequence
|
||||
@ -18,7 +18,7 @@ from petals.dht_utils import get_remote_sequence
|
||||
@pytest.mark.forked
|
||||
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
|
||||
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
||||
config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
||||
config = petals.DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
||||
remote_blocks = get_remote_sequence(dht, 3, 6, config)
|
||||
assert isinstance(remote_blocks, RemoteSequential)
|
||||
|
||||
@ -47,7 +47,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
|
||||
@pytest.mark.forked
|
||||
def test_chained_inference_exact_match(atol_inference=1e-4):
|
||||
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
||||
config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
||||
config = petals.DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
||||
remote_blocks = get_remote_sequence(dht, 3, 5, config)
|
||||
assert isinstance(remote_blocks, RemoteSequential)
|
||||
|
||||
|
@ -3,7 +3,7 @@ import torch
|
||||
from hivemind import DHT, get_logger, use_hivemind_log_handler
|
||||
from test_utils import *
|
||||
|
||||
from petals.import RemoteSequential
|
||||
from petals import RemoteSequential
|
||||
from petals.bloom.from_pretrained import load_pretrained_block
|
||||
from petals.client.remote_model import DistributedBloomConfig
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user