black-isort

fix-auth-token
justheuristic 2 years ago
parent 01b9bced78
commit d03b38b9eb

@ -1,3 +1,5 @@
from .bloom import *
from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
from .dht_utils import declare_active_modules, get_remote_module
__version__ = "0.1"

@ -9,8 +9,15 @@ import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
pre_process_alibi_for_pad, split_tensor_along_last_dim)
from src.bloom.ops import (
BloomGelu,
BloomScaledSoftmax,
attention_mask_func,
build_alibi_tensor,
dropout_add,
pre_process_alibi_for_pad,
split_tensor_along_last_dim,
)
class BloomAttention(nn.Module):

@ -11,8 +11,11 @@ import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward)
from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

@ -1,8 +1,8 @@
from typing import NamedTuple, Collection
from typing import Collection, NamedTuple
from hivemind import PeerID
ModuleUID = str
UID_DELIMITER = '.'
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])

@ -12,7 +12,7 @@ from hivemind.p2p import P2P, PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
import src
from src.data_structures import UID_DELIMITER, ModuleUID, RemoteModuleInfo
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -39,7 +39,7 @@ def declare_active_modules(
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
return_future=not wait,
@ -68,7 +68,7 @@ def get_remote_module(
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
expiration_time: Optional[DHTExpiration] = None,
return_future: bool = False,
) -> Union[List[Optional["src.RemoteTransformerBlock"]], MPFuture[List[Optional["src.RemoteTransformerBlock"]]]]:
) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
"""
:param uid_or_uids: find one or more modules with these ids from across the DHT
:param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)

Loading…
Cancel
Save