Fix imports

This commit is contained in:
Aleksandr Borzunov 2022-11-30 03:36:27 +00:00
parent 625d042d0a
commit 59db85174e
28 changed files with 106 additions and 107 deletions

View File

@@ -140,11 +140,10 @@ Once you have enough servers, you can use them to train and/or inference the mo
```python
import torch
import torch.nn.functional as F
import transformers
from src import DistributedBloomForCausalLM
from petals import BloomTokenizerFast, DistributedBloomForCausalLM
initial_peers = [TODO_put_one_or_more_server_addresses_here] # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
tokenizer = BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
model = DistributedBloomForCausalLM.from_pretrained(
"bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
) # this model has only embeddings / logits, all transformer blocks rely on remote servers
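# --- illustrative continuation (not part of the diff hunk above) ---
# A minimal sketch of how the snippet might be used once the model is built;
# the prompt text and max_new_tokens value below are arbitrary examples.
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
remote_outputs = model.generate(inputs, max_new_tokens=5)  # forward passes run on remote servers
print(tokenizer.decode(remote_outputs[0]))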

View File

@@ -9,7 +9,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from huggingface_hub import Repository
from tqdm.auto import tqdm
from src import BloomModel
from petals import BloomModel
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
from petals.client import DistributedBloomConfig

View File

@@ -7,7 +7,7 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn
import src
import petals
from petals.client.inference_session import InferenceSession
from petals.client.sequence_manager import RemoteSequenceManager
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
@@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
def __init__(
self,
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht: DHT,
dht_prefix: Optional[str] = None,
p2p: Optional[P2P] = None,

View File

@@ -12,7 +12,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
import src
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
use_hivemind_log_handler("in_root_logger")
@@ -76,10 +76,10 @@ def get_remote_sequence(
dht: DHT,
start: int,
stop: int,
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
return_future: bool = False,
) -> Union[src.RemoteSequential, MPFuture]:
) -> Union[petals.RemoteSequential, MPFuture]:
return RemoteExpertWorker.run_coroutine(
_get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
)
@@ -89,22 +89,22 @@ async def _get_remote_sequence(
dht: DHT,
start: int,
stop: int,
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
) -> src.RemoteSequential:
) -> petals.RemoteSequential:
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
p2p = await dht.replicate_p2p()
manager = src.RemoteSequenceManager(dht, uids, p2p)
return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
manager = petals.RemoteSequenceManager(dht, uids, p2p)
return petals.RemoteSequential(config, dht, dht_prefix, p2p, manager)
def get_remote_module(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
return_future: bool = False,
) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
) -> Union[Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]], MPFuture]:
"""
:param uid_or_uids: find one or more modules with these ids from across the DHT
:param config: model config, usually taken by .from_pretrained(MODEL_NAME)
@@ -119,15 +119,15 @@ def get_remote_module(
async def _get_remote_module(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
) -> Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]]:
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
p2p = await dht.replicate_p2p()
managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
managers = (petals.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
modules = [
src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
petals.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
]
return modules[0] if single_uid else modules
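For reference, a minimal sketch of calling `get_remote_sequence` after the rename, mirroring the test code later in this commit; `initial_peers` and the model name below are placeholders taken from the README example above:

```python
import hivemind
import petals
from petals.dht_utils import get_remote_sequence

# Placeholders: point these at your own bootstrap peers and converted model repo.
initial_peers = ["/ip4/127.0.0.1/tcp/more/stuff/here"]
model_name = "bloom-testing/test-bloomd-560m-main"

dht = hivemind.DHT(initial_peers=initial_peers, client_mode=True, start=True)
config = petals.DistributedBloomConfig.from_pretrained(model_name)
remote_blocks = get_remote_sequence(dht, 3, 6, config)  # RemoteSequential over blocks 3..5
```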

View File

@@ -16,7 +16,7 @@ from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import BloomConfig, declare_active_modules
from petals import BloomConfig, declare_active_modules
from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState

View File

@@ -11,7 +11,7 @@ from typing import Dict, Union
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import project_name
from petals import project_name
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor

View File

@@ -1,6 +1,6 @@
from src.bloom import *
from src.client import *
from src.dht_utils import declare_active_modules, get_remote_module
from petals.bloom import *
from petals.client import *
from petals.dht_utils import declare_active_modules, get_remote_module
project_name = "bloomd"
__version__ = "0.2"
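In practice, this re-export block is what lets the rest of the diff refer to `petals.DistributedBloomConfig`, `petals.RemoteSequential`, and similar top-level names. A short sketch of the resulting import style (names taken from the hunks in this commit):

```python
import petals
from petals import BloomConfig, DistributedBloomForCausalLM, RemoteSequential

config_cls = petals.DistributedBloomConfig  # re-exported via `from petals.client import *`
```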

View File

@@ -1,2 +1,2 @@
from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

View File

@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
from src.bloom.ops import (
from petals.bloom.ops import (
BloomGelu,
BloomScaledSoftmax,
attention_mask_func,

View File

@@ -15,7 +15,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.utils.hub import cached_path, hf_bucket_url
from src.bloom import BloomBlock, BloomConfig
from petals.bloom import BloomBlock, BloomConfig
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

View File

@@ -26,7 +26,7 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
from transformers.utils import logging
from src.bloom.block import BloomBlock
from petals.bloom.block import BloomBlock
use_hivemind_log_handler("in_root_logger")
logger = logging.get_logger(__file__)

View File

@@ -1,5 +1,5 @@
from src.client.inference_session import InferenceSession
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager
from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
from petals.client.inference_session import InferenceSession
from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from petals.client.sequence_manager import RemoteSequenceManager
from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

View File

@@ -20,10 +20,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy
from petals.client.sequence_manager import RemoteSequenceManager
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__file__)

View File

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
from src.data_structures import ModuleUID, RPCInfo
from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary(

View File

@@ -3,14 +3,14 @@ from typing import List, Optional
import torch
from hivemind.utils.logging import get_logger
from src.utils.generation_algorithms import (
from petals.utils.generation_algorithms import (
BeamSearchAlgorithm,
DecodingAlgorithm,
GreedyAlgorithm,
NucleusAlgorithm,
TopKAlgorithm,
)
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
logger = get_logger(__file__)
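The algorithm classes imported here back the sampling behaviour exposed through the client's `generate()`; a hedged sketch of how they are typically selected (keyword names assumed from the usual HF-style interface, with `model` and `inputs` as in the README snippet above):

```python
# Greedy decoding is the default; passing sampling arguments switches the decoding algorithm.
greedy_output = model.generate(inputs, max_new_tokens=8)                              # greedy
sampled_output = model.generate(inputs, max_new_tokens=8, do_sample=True, top_k=40)   # top-k sampling
```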

View File

@@ -7,7 +7,7 @@ import torch.nn as nn
from hivemind import get_logger, use_hivemind_log_handler
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from src.bloom.model import (
from petals.bloom.model import (
BloomConfig,
BloomForCausalLM,
BloomForSequenceClassification,
@@ -15,10 +15,10 @@ from src.bloom.model import (
BloomPreTrainedModel,
LMHead,
)
from src.client.remote_generation import RemoteGenerationMixin
from src.client.remote_sequential import RemoteSequential
from src.constants import PUBLIC_INITIAL_PEERS
from src.utils.misc import DUMMY
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.utils.misc import DUMMY
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

View File

@@ -7,12 +7,12 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn
import src
from src.client.inference_session import InferenceSession
from src.client.sequence_manager import RemoteSequenceManager
from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
from src.data_structures import UID_DELIMITER
from src.utils.misc import DUMMY
import petals
from petals.client.inference_session import InferenceSession
from petals.client.sequence_manager import RemoteSequenceManager
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
from petals.utils.misc import DUMMY
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
def __init__(
self,
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht: DHT,
dht_prefix: Optional[str] = None,
p2p: Optional[P2P] = None,

View File

@@ -9,10 +9,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src.client.spending_policy import NoSpendingPolicy
from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from src.dht_utils import get_remote_module_infos
from src.server.handler import TransformerConnectionHandler
from petals.client.spending_policy import NoSpendingPolicy
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.dht_utils import get_remote_module_infos
from petals.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

View File

@@ -11,11 +11,11 @@ import torch
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.sequence_manager import RemoteSequenceManager
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__file__)

View File

@@ -12,8 +12,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
import src
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -76,10 +76,10 @@ def get_remote_sequence(
dht: DHT,
start: int,
stop: int,
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
return_future: bool = False,
) -> Union[src.RemoteSequential, MPFuture]:
) -> Union[petals.RemoteSequential, MPFuture]:
return RemoteExpertWorker.run_coroutine(
_get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
)
@@ -89,22 +89,22 @@ async def _get_remote_sequence(
dht: DHT,
start: int,
stop: int,
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
) -> src.RemoteSequential:
) -> petals.RemoteSequential:
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
p2p = await dht.replicate_p2p()
manager = src.RemoteSequenceManager(dht, uids, p2p)
return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
manager = petals.RemoteSequenceManager(dht, uids, p2p)
return petals.RemoteSequential(config, dht, dht_prefix, p2p, manager)
def get_remote_module(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
return_future: bool = False,
) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
) -> Union[Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]], MPFuture]:
"""
:param uid_or_uids: find one or more modules with these ids from across the DHT
:param config: model config, usually taken by .from_pretrained(MODEL_NAME)
@@ -119,15 +119,15 @@ def get_remote_module(
async def _get_remote_module(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
config: src.DistributedBloomConfig,
config: petals.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
) -> Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]]:
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
p2p = await dht.replicate_p2p()
managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
managers = (petals.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
modules = [
src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
petals.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
]
return modules[0] if single_uid else modules

View File

@@ -6,10 +6,10 @@ from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from src.bloom.from_pretrained import BloomBlock
from src.server.cache import MemoryCache
from src.server.task_pool import PrioritizedTaskPool
from src.utils.misc import is_dummy
from petals.bloom.from_pretrained import BloomBlock
from petals.server.cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

View File

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
from hivemind import PeerID, get_logger
from src.data_structures import RemoteModuleInfo, ServerState
from petals.data_structures import RemoteModuleInfo, ServerState
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]

View File

@@ -21,11 +21,11 @@ from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import TransformerBackend
from src.server.task_pool import PrioritizedTaskPool
from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from src.utils.misc import DUMMY, is_dummy
from petals.data_structures import CHAIN_DELIMITER, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__file__)

View File

@@ -16,17 +16,17 @@ from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import BloomConfig, declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.constants import PUBLIC_INITIAL_PEERS
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from src.dht_utils import get_remote_module_infos
from src.server import block_selection
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
from src.server.throughput import get_host_throughput
from src.utils.convert_8bit import replace_8bit_linear
from petals import BloomConfig, declare_active_modules
from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from petals.dht_utils import get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend
from petals.server.cache import MemoryCache
from petals.server.handler import TransformerConnectionHandler
from petals.server.throughput import get_host_throughput
from petals.utils.convert_8bit import replace_8bit_linear
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

View File

@@ -11,10 +11,10 @@ from typing import Dict, Union
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import project_name
from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig
from src.bloom.ops import build_alibi_tensor
from petals import project_name
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

View File

@@ -7,8 +7,8 @@ import transformers
from hivemind import P2PHandlerError
from test_utils import *
import src
from src import DistributedBloomConfig
import petals
from petals import DistributedBloomConfig
from petals.bloom.from_pretrained import load_pretrained_block
from petals.client.remote_sequential import RemoteTransformerBlock
from petals.data_structures import UID_DELIMITER

View File

@@ -9,7 +9,7 @@ import pytest
import torch
from test_utils import *
import src
import petals
from petals.bloom.from_pretrained import load_pretrained_block
from petals.client.remote_sequential import RemoteSequential
from petals.dht_utils import get_remote_sequence
@@ -18,7 +18,7 @@ from petals.dht_utils import get_remote_sequence
@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
config = petals.DistributedBloomConfig.from_pretrained(MODEL_NAME)
remote_blocks = get_remote_sequence(dht, 3, 6, config)
assert isinstance(remote_blocks, RemoteSequential)
@@ -47,7 +47,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
config = petals.DistributedBloomConfig.from_pretrained(MODEL_NAME)
remote_blocks = get_remote_sequence(dht, 3, 5, config)
assert isinstance(remote_blocks, RemoteSequential)

View File

@@ -3,7 +3,7 @@ import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from test_utils import *
from src import RemoteSequential
from petals import RemoteSequential
from petals.bloom.from_pretrained import load_pretrained_block
from petals.client.remote_model import DistributedBloomConfig