black-isort

client
justheuristic 2 years ago
parent 7d68f6b9a4
commit 471e47c0f5
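
The commit title refers to running the black formatter and the isort import sorter over the codebase. A minimal sketch of that kind of formatting pass, assuming black and isort are installed; the line length and isort settings below are assumptions, not the project's verified configuration:

# Hedged sketch: apply an isort + black pass to a snippet (configuration values are assumptions).
import black
import isort

source = """\
from src.bloom.ops import (
    BloomGelu,
    BloomScaledSoftmax,
)
x=1
"""

sorted_source = isort.code(source, line_length=120)  # isort: normalize the import block
formatted = black.format_str(sorted_source, mode=black.Mode(line_length=120))  # black: reformat code layout
print(formatted)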

@@ -9,15 +9,8 @@ import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
from src.bloom.ops import (
    BloomGelu,
    BloomScaledSoftmax,
    attention_mask_func,
    build_alibi_tensor,
    dropout_add,
    pre_process_alibi_for_pad,
    split_tensor_along_last_dim,
)
from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
class BloomAttention(nn.Module):

@@ -11,11 +11,8 @@ import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
                                     add_start_docstrings_to_model_forward)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

@@ -1,29 +1,22 @@
# this code is in active development, interfaces may change
from typing import Optional
from hivemind import DHT, use_hivemind_log_handler, get_logger
from hivemind import DHT, get_logger, use_hivemind_log_handler
from src.bloom import DistributedBloomConfig, BloomForCausalLM
from src.bloom import BloomForCausalLM, DistributedBloomConfig
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class DistributedBloomForCausalLM(BloomForCausalLM):
"""BloomForCausalLM, but all transformer layers are hosted by the swarm"""
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
if prefix.endswith(UID_DELIMITER):
logger.warning(f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
f"This will cause {self.__class__.__name__} to look for modules under "
f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended.")
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None):
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
super().__init__(config)
assert len(self.transformer.h) == 0
config.n_layer = n_layer
self.transformer.h = RemoteSequential(config, dht, prefix)
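
The __init__ above temporarily reports zero layers so that the parent constructor builds no local transformer blocks, then restores the real depth and attaches RemoteSequential in their place. A toy, self-contained sketch of that pattern (the classes below are illustrative stand-ins, not the real model):

# Toy illustration of the "temporarily report zero layers" trick (stand-in classes, not the real model).
class ToyConfig:
    n_layer = 24

class ToyBackbone:
    def __init__(self, config):
        self.h = [object() for _ in range(config.n_layer)]  # stands in for transformer.h

config = ToyConfig()
n_layer, config.n_layer = config.n_layer, 0  # parent init now builds zero local blocks
backbone = ToyBackbone(config)
assert len(backbone.h) == 0
config.n_layer = n_layer  # restore the real depth
backbone.h = [f"remote_block_{i}" for i in range(config.n_layer)]  # swap in remote layers
print(len(backbone.h))  # 24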

@@ -1,16 +1,15 @@
import logging
from functools import partial
from typing import Tuple
from typing import Optional, Tuple
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn
from src import DistributedBloomConfig
from src.data_structures import UID_DELIMITER, RemoteModuleInfo
from src.dht_utils import _get_remote_module_infos, _create_remote_modules_from_infos
from hivemind import DHT, use_hivemind_log_handler, get_logger
from src.dht_utils import _create_remote_modules_from_infos, _get_remote_module_infos
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -18,7 +17,16 @@ logger = get_logger(__file__)
class RemoteSequential(nn.Sequential):
"""A sequence of transformer blocks hosted by the swarm"""
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None, max_retries: int = 3):
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
if prefix.endswith(UID_DELIMITER):
logger.warning(
f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
f"This will cause {self.__class__.__name__} to look for modules under "
f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
)
super().__init__()
self.config = config
self.dht = dht
@@ -27,9 +35,12 @@ class RemoteSequential(nn.Sequential):
        self.prefix = prefix
        self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
        logger.debug(f"Remote block uids: {self.block_uids}")
        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(dht.run_coroutine(
            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float('inf')), return_future=False
        ))
        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(
            dht.run_coroutine(
                partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
                return_future=False,
            )
        )
        self.max_retries = max_retries
@@ -39,7 +50,7 @@ class RemoteSequential(nn.Sequential):
            assert info is not None, f"Found no active peers for block {uid}"
            assert isinstance(info.peer_ids, set), f"expected peer_ids to be a set, got {info.peer_ids}"
            assert info.uid == uid, f"The DHT entry for {uid} actually points to {info.uid}"
            assert len(info.peer_ids) > 0, f"Found no active peers for block {uid}"
            assert len(info.peer_ids) > 0, f"Found no active peers for block {uid}"
    def forward(self, inputs: torch.Tensor):
        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -47,7 +58,7 @@ class RemoteSequential(nn.Sequential):
            for retry_index in range(self.max_retries):
                try:
                    block = self[block_index]
                    outputs, = block(inputs)
                    (outputs,) = block(inputs)
                    assert isinstance(outputs, torch.Tensor)
                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                    inputs = outputs
@@ -61,7 +72,7 @@ class RemoteSequential(nn.Sequential):
    def __getitem__(self, block_index: int):
        assert 0 <= block_index < self.config.n_layer
        module, = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
        (module,) = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
        return module
    def __iter__(self):
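
The prefix warnings added in this commit guard against a doubled delimiter in module uids. A self-contained sketch of the uid scheme (the UID_DELIMITER value below is assumed for illustration only):

# Illustration of why a prefix that already ends with UID_DELIMITER is suspicious
# (the "." delimiter value here is an assumption for the example).
UID_DELIMITER = "."

def block_uids(prefix: str, n_layer: int):
    # mirrors RemoteSequential: one uid per transformer block
    return tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(n_layer))

print(block_uids("bloom", 3))   # ('bloom.0', 'bloom.1', 'bloom.2')
print(block_uids("bloom.", 3))  # ('bloom..0', 'bloom..1', 'bloom..2')  <- doubled delimiter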
