8bit_blocks
justheuristic 2 years ago
parent ed468af8d6
commit 3b9351de1c

@@ -1,13 +1,12 @@
 import argparse
 import torch
-from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from tqdm.auto import trange
-from src.bloom.model import DistributedBloomConfig
 from src.bloom.block import BloomBlock
+from src.bloom.model import DistributedBloomConfig
 from src.bloom.ops import build_alibi_tensor
-from tqdm.auto import trange
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -1,5 +1,4 @@
 import configargparse
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler

@@ -1 +1 @@
-from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig
+from src.bloom.model import BloomForCausalLM, BloomModel, DistributedBloomConfig

@@ -9,15 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
-from src.bloom.ops import (
-    BloomGelu,
-    BloomScaledSoftmax,
-    attention_mask_func,
-    dropout_add,
-    pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
-    build_alibi_tensor,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
 class BloomAttention(nn.Module):

@@ -11,11 +11,8 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+                                     add_start_docstrings_to_model_forward)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

@@ -1,5 +1,5 @@
-from typing import Sequence
 from collections import defaultdict
+from typing import Sequence
 import torch
 from hivemind import DHT

@@ -1,16 +1,18 @@
 from __future__ import annotations
 import asyncio
 from functools import partial
-from typing import List, Optional, Union, Sequence, AsyncIterator, Dict, Any
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Union
 import torch
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertUID, ExpertInfo as RemoteModuleInfo
+from hivemind.moe.expert_uid import ExpertInfo as RemoteModuleInfo
+from hivemind.moe.expert_uid import ExpertUID
 from hivemind.p2p import P2P, PeerID, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.utils import MPFuture, DHTExpiration, get_dht_time, as_aiter, anext, nested_flatten
-from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import DHTExpiration, MPFuture, anext, as_aiter, get_dht_time, nested_flatten
 from src.server.handler import TransformerConnectionHandler

@@ -1,5 +1,5 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Tuple, Sequence
+from typing import Sequence, Tuple
 import torch
 from hivemind.moe.server.module_backend import ModuleBackend

@@ -1,12 +1,12 @@
 from typing import AsyncIterator, Dict
 import torch
-from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor, nested_flatten
+from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
-from src.server.backend import TransformerBackend, MAX_LENGTH
+from src.server.backend import MAX_LENGTH, TransformerBackend
 class TransformerConnectionHandler(ConnectionHandler):

@@ -1,6 +1,8 @@
 from __future__ import annotations
+import multiprocessing as mp
 import threading
-from typing import Optional, Dict, Union, Sequence
+from typing import Dict, Optional, Sequence, Union
 import torch
 from hivemind import DHT, BatchTensorDescriptor
@@ -8,13 +10,12 @@ from hivemind.moe.server.dht_handler import DHTHandlerThread
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.logging import use_hivemind_log_handler, get_logger
-import multiprocessing as mp
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from src import DistributedBloomConfig, BloomForCausalLM
+from src import BloomForCausalLM, DistributedBloomConfig
 from src.bloom.block import BloomBlock
-from src.server.cache import MemoryCache
 from src.server.backend import TransformerBackend
+from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 use_hivemind_log_handler("in_root_logger")