black-isort

8bit_blocks
justheuristic 2 years ago
parent 2bf83b42e5
commit 83cd4412a1

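The hunks below are a pure formatting pass: black normalizes quotes and signature wrapping, and isort reorders and re-wraps the import blocks; no runtime behavior changes. A minimal sketch of reproducing such a pass follows; the line length, wrap mode, and black-then-isort order are assumptions, since the diff does not record the tool options.

import subprocess

# Hypothetical re-run of the formatters named in the commit title ("black-isort").
# Options and ordering are assumptions; adjust to the project's actual configuration.
subprocess.run(["black", "--line-length", "120", "."], check=True)
subprocess.run(["isort", "--line-length", "120", "."], check=True)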
@@ -3,10 +3,10 @@ import os
import psutil
import torch.backends.quantized
+import torch.nn as nn
import transformers
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from huggingface_hub import Repository
-import torch.nn as nn
from tqdm.auto import tqdm
use_hivemind_log_handler("in_root_logger")
@@ -85,4 +85,3 @@ if __name__ == "__main__":
config.save_pretrained(".")
logger.info(f"Converted {args.model} and pushed to {args.output_repo}")

@@ -1 +1 @@
-from src.bloom.model import BloomForCausalLM, BloomModel, DistributedBloomConfig, BloomBlock
+from src.bloom.model import BloomBlock, BloomForCausalLM, BloomModel, DistributedBloomConfig

@@ -9,15 +9,8 @@ import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
-from src.bloom.ops import (
-BloomGelu,
-BloomScaledSoftmax,
-attention_mask_func,
-build_alibi_tensor,
-dropout_add,
-pre_process_alibi_for_pad,
-split_tensor_along_last_dim,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+pre_process_alibi_for_pad, split_tensor_along_last_dim)
class BloomAttention(nn.Module):

@@ -11,18 +11,18 @@ from __future__ import annotations
from typing import Optional, OrderedDict, Union
import torch
-from hivemind.utils.logging import use_hivemind_log_handler, get_logger
-from transformers.utils.hub import hf_bucket_url, cached_path
-from src.bloom import BloomForCausalLM, DistributedBloomConfig, BloomBlock
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers.modeling_utils import WEIGHTS_NAME
+from transformers.utils.hub import cached_path, hf_bucket_url
+from src.bloom import BloomBlock, BloomForCausalLM, DistributedBloomConfig
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
CLIENT_BRANCH = "client"
BLOCK_BRANCH_PREFIX = "block_"
-USER_AGENT = {'file_type': 'model', 'framework': 'pytorch', 'from_auto_class': False}
+USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
cls = BloomForCausalLM
FORCE_DOWNLOAD = False
RESUME_DOWNLOAD = False
@@ -30,8 +30,11 @@ LOCAL_FILES_ONLY = False
def load_pretrained_block(
-converted_model_name_or_path: str, block_index: int,
-config: Optional[DistributedBloomConfig] = None, torch_dtype: Union[torch.dtype, str] = 'auto') -> BloomBlock:
+converted_model_name_or_path: str,
+block_index: int,
+config: Optional[DistributedBloomConfig] = None,
+torch_dtype: Union[torch.dtype, str] = "auto",
+) -> BloomBlock:
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
if config is None:
config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path)
@@ -39,7 +42,7 @@ def load_pretrained_block(
state_dict = _load_state_dict(converted_model_name_or_path, block_index)
block.load_state_dict(state_dict)
-if torch_dtype == 'auto':
+if torch_dtype == "auto":
with torch.no_grad():
for name, param in block.named_parameters():
assert name in state_dict, f"{name} not in state dict"
@@ -54,7 +57,8 @@ def load_pretrained_block(
def _load_state_dict(
-pretrained_model_name_or_path: str, block_index: Optional[int] = None) -> OrderedDict[str, torch.Tensor]:
+pretrained_model_name_or_path: str, block_index: Optional[int] = None
+) -> OrderedDict[str, torch.Tensor]:
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
@@ -69,7 +73,7 @@ def _load_state_dict(
use_auth_token=True,
user_agent=USER_AGENT,
)
-state_dict = torch.load(resolved_archive_file, map_location='cpu')
+state_dict = torch.load(resolved_archive_file, map_location="cpu")
return state_dict

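The reformatted load_pretrained_block and _load_state_dict above change only layout and quoting, not the call signature. A minimal usage sketch, assuming a converted checkpoint repo with one "block_<i>" branch per layer; the repo name and block index below are illustrative, not taken from this diff.

from src.bloom.from_pretrained import load_pretrained_block

# "bigscience/test-bloomd-6b3" is a hypothetical converted model name.
block = load_pretrained_block(
    "bigscience/test-bloomd-6b3",
    block_index=3,
    torch_dtype="auto",  # "auto" presumably keeps the dtypes stored in the checkpoint, per the branch above
)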
@@ -11,11 +11,8 @@ import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-add_code_sample_docstrings,
-add_start_docstrings,
-add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+add_start_docstrings_to_model_forward)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

@@ -7,8 +7,8 @@ import math
import torch
import torch.autograd
-from torch import nn
import torch.nn.functional as F
+from torch import nn
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):

@@ -35,7 +35,7 @@ class TransformerBackend(ModuleBackend):
print("METADATA:", cache_metadata)
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-print('PAST', past_k.shape, past_v.shape)
+print("PAST", past_k.shape, past_v.shape)
hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
# todo remove these asserts once we pass all tests

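The backend hunk above only constrains the attention cache to a 5-D tensor whose first axis packs past keys and values; the slicing below reproduces it with made-up sizes (the meaning of the remaining axes is not shown in this diff).

import torch

# Placeholder shape: the code above only guarantees cache.shape[0] == 2 and cache.ndim == 5.
cache = torch.zeros(2, 1, 8, 16, 64)
prefix_length = 4
# Keys sit in cache[0], values in cache[1]; each slice keeps the first prefix_length entries of the third axis.
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
print("PAST", past_k.shape, past_v.shape)  # torch.Size([1, 4, 16, 64]) for both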
@@ -12,12 +12,11 @@ from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from src.bloom.from_pretrained import load_pretrained_block, DistributedBloomConfig, DTYPE_MAP
+from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -89,7 +88,7 @@ class Server(threading.Thread):
num_handlers: Optional[int] = None,
min_batch_size: int = 1,
max_batch_size: int = 4096,
-torch_dtype: str = 'auto',
+torch_dtype: str = "auto",
cache_size_bytes: Optional[int] = None,
device: Union[str, torch.device] = None,
initial_peers: Sequence[str] = (),
