black-isort

fix-auth-token
justheuristic 2 years ago
parent 01b9bced78
commit d03b38b9eb

@ -1,3 +1,5 @@
from .bloom import * from .bloom import *
from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
from .dht_utils import declare_active_modules, get_remote_module from .dht_utils import declare_active_modules, get_remote_module
__version__ = "0.1"

@ -9,8 +9,15 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear import torch.nn.quantized.dynamic.modules.linear
from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add, from src.bloom.ops import (
pre_process_alibi_for_pad, split_tensor_along_last_dim) BloomGelu,
BloomScaledSoftmax,
attention_mask_func,
build_alibi_tensor,
dropout_add,
pre_process_alibi_for_pad,
split_tensor_along_last_dim,
)
class BloomAttention(nn.Module): class BloomAttention(nn.Module):

@ -11,8 +11,11 @@ import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler from hivemind import use_hivemind_log_handler
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings, from transformers.file_utils import (
add_start_docstrings_to_model_forward) add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

@ -1,8 +1,8 @@
from typing import NamedTuple, Collection from typing import Collection, NamedTuple
from hivemind import PeerID from hivemind import PeerID
ModuleUID = str ModuleUID = str
UID_DELIMITER = '.' UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])]) RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])

@ -12,7 +12,7 @@ from hivemind.p2p import P2P, PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
import src import src
from src.data_structures import UID_DELIMITER, ModuleUID, RemoteModuleInfo from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
use_hivemind_log_handler("in_root_logger") use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__) logger = get_logger(__file__)
@ -39,7 +39,7 @@ def declare_active_modules(
if not isinstance(uids, list): if not isinstance(uids, list):
uids = list(uids) uids = list(uids)
for uid in uids: for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine( return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput), partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
return_future=not wait, return_future=not wait,
@ -68,7 +68,7 @@ def get_remote_module(
uid_or_uids: Union[ModuleUID, List[ModuleUID]], uid_or_uids: Union[ModuleUID, List[ModuleUID]],
expiration_time: Optional[DHTExpiration] = None, expiration_time: Optional[DHTExpiration] = None,
return_future: bool = False, return_future: bool = False,
) -> Union[List[Optional["src.RemoteTransformerBlock"]], MPFuture[List[Optional["src.RemoteTransformerBlock"]]]]: ) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
""" """
:param uid_or_uids: find one or more modules with these ids from across the DHT :param uid_or_uids: find one or more modules with these ids from across the DHT
:param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time) :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)

Loading…
Cancel
Save