black-isort

justheuristic 2 years ago
parent 4ad845bce3
commit e32208c954

@@ -9,15 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
-from src.bloom.ops import (
-    BloomGelu,
-    BloomScaledSoftmax,
-    attention_mask_func,
-    build_alibi_tensor,
-    dropout_add,
-    pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
 
 class BloomAttention(nn.Module):

@@ -9,11 +9,8 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+                                     add_start_docstrings_to_model_forward)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

@@ -11,13 +11,12 @@ from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, nested_flatten, use_hivemind_log_handler, get_logger
+from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
 
 from src.data_structures import RemoteModuleInfo
 from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -1,20 +1,18 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Union, Tuple
+from typing import Optional, Tuple, Union
 
 import hivemind
+import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
 from src.bloom import BloomForYou, DistributedBloomConfig
 from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
-
-import torch
-from hivemind import use_hivemind_log_handler
-from torch.nn import CrossEntropyLoss
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
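
The hunk above is the most involved one in the commit: rather than swapping single lines, isort rebuilds the file's whole import block, moving the stray third-party imports ahead of the first-party src ones and folding the duplicate `from hivemind import use_hivemind_log_handler` into the combined hivemind import. A toy module sketching the ordering rule it applies (the grouping is isort's documented default; treating `src` as first-party is inferred from the hunks in this diff, not from any recorded config):

# isort's default layout: one group per section, a blank line between groups.
from __future__ import annotations  # "future" imports always come first

import os  # 1) standard library

import torch  # 2) third-party: plain "import x" lines first, alphabetized...
from hivemind import DHT, get_logger  # ...then "from x import y" lines

from src.data_structures import UID_DELIMITER  # 3) first-party code ("src" here)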

@@ -3,10 +3,10 @@ from __future__ import annotations
 import dataclasses
 import threading
 from functools import partial
-from typing import Tuple, List, Optional, Sequence, NamedTuple
+from typing import List, NamedTuple, Optional, Sequence, Tuple
 
 from hivemind import DHT, PeerID
-from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src.data_structures import ModuleUID, RemoteModuleInfo
 from src.dht_utils import _get_remote_module_infos

@@ -15,7 +15,6 @@ from src.client.remote_sequence_info import RemoteSequenceInfo
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
 
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -5,11 +5,11 @@ from typing import AsyncIterator, Dict, Sequence
 import torch
 from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
 from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
+from hivemind.utils import as_aiter
 from hivemind.utils.asyncio import anext
 from hivemind.utils.streaming import split_for_streaming
-from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
-from hivemind.utils import as_aiter
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend

@@ -14,7 +14,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from src import declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig
-from src.data_structures import UID_DELIMITER, CHAIN_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler

@@ -3,7 +3,7 @@ import os
 import torch
 import transformers
-from hivemind import use_hivemind_log_handler, get_logger
+from hivemind import get_logger, use_hivemind_log_handler
 
 from src.client.remote_model import DistributedBloomForCausalLM
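
Taken together, the hunks are consistent with a single formatter pass over the tree. A hypothetical driver sketch of such a pass (the line length of 120 is a guess read off where the long imports wrap, e.g. the BaseModelOutputWithPastAndCrossAttentions line is exactly 120 characters; the actual invocation is not recorded in this diff):

# Hypothetical reproduction script; "120" is an inferred, not recorded, value.
import subprocess

# black normalizes general formatting, e.g. collapsing doubled blank lines.
subprocess.run(["black", "--line-length", "120", "."], check=True)

# isort runs last so its default "grid" wrapping wins for imports: packed names
# aligned under the open paren, the style visible in the first two hunks.
# (black alone would explode such long imports to one name per line instead.)
subprocess.run(["isort", "--line-length", "120", "."], check=True)

Pinning the same line length for both tools matters here: if they disagreed, each run would re-wrap the other's output and the diff would never converge.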
