Add LLaMA support (#323)
This PR: 1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms. - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ. - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accommodate blocks from different repos (there are a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name). 2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`. 3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.). 4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers. Upgrade instructions: - Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
parent
5c0733711a
commit
cb3f018f9f
@ -1,62 +0,0 @@
|
|||||||
"""
|
|
||||||
Bloom intermediate layer
|
|
||||||
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
|
||||||
See commit history for authorship.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch.nn.quantized.dynamic.modules.linear
|
|
||||||
import transformers
|
|
||||||
from packaging import version
|
|
||||||
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor
|
|
||||||
|
|
||||||
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
|
|
||||||
assert (
|
|
||||||
version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
|
|
||||||
), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"
|
|
||||||
|
|
||||||
|
|
||||||
class WrappedBloomBlock(BloomBlock):
    """A BloomBlock that reconstructs its own attention mask and ALiBi tensor when the caller omits them."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        # Callers must not pass an explicit attention mask — it is rebuilt below from the sequence lengths.
        assert attention_mask is None
        batch_size, seq_length = hidden_states.shape[:2]
        # layer_past[0] holds cached keys; its last dimension is the cached sequence length.
        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
        seq_length_with_past = seq_length + past_length
        # Start from an all-ones mask covering both cached and newly fed positions.
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        # A causal mask is only needed when more than one new token is processed at once
        # (during incremental decoding src_length == 1, so every cached position may be attended to).
        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        # Combine padding and causal masks with boolean OR (True = masked out, per BLOOM's convention).
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask
|
@ -1,132 +0,0 @@
|
|||||||
"""
|
|
||||||
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
|
|
||||||
If necessary, one can rewrite this to implement a different behavior, such as:
|
|
||||||
- loading files from a local data source (e.g. S3)
|
|
||||||
- load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
|
|
||||||
- fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
import time
|
|
||||||
from typing import Optional, OrderedDict, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from accelerate.utils import set_module_tensor_to_device
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from transformers.modeling_utils import WEIGHTS_NAME
|
|
||||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
||||||
from transformers.utils import get_file_from_repo
|
|
||||||
|
|
||||||
from petals.bloom.block import WrappedBloomBlock
|
|
||||||
from petals.server.block_utils import get_block_size, resolve_block_dtype
|
|
||||||
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
CLIENT_BRANCH = "main"
|
|
||||||
BLOCK_BRANCH_PREFIX = "block_"
|
|
||||||
|
|
||||||
|
|
||||||
def load_pretrained_block(
    converted_model_name_or_path: str,
    block_index: int,
    config: Optional[BloomConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> WrappedBloomBlock:
    """
    Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it.

    :param converted_model_name_or_path: HF Hub repo that stores one block per branch (see BLOCK_BRANCH_PREFIX)
    :param block_index: index of the transformer block to load
    :param config: model config; downloaded from the repo when not provided
    :param torch_dtype: dtype to cast floating-point weights to, or "auto" to keep the resolved checkpoint dtype
    :param use_auth_token: optional HF Hub token for private repos
    :param cache_dir: local cache directory (defaults to DEFAULT_CACHE_DIR)
    :param max_disk_space: disk-space budget enforced while downloading into the cache
    """
    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"

    if config is None:
        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR
    # Fix: resolve the dtype only after `config` is guaranteed to exist. The previous version called
    # resolve_block_dtype(config, torch_dtype) before the `config is None` fallback, passing None
    # whenever the caller did not supply a config.
    torch_dtype = resolve_block_dtype(config, torch_dtype)

    # Instantiate the block on meta tensors first; real weights are filled in below.
    with init_empty_weights():
        block = WrappedBloomBlock(config)

    state_dict = _load_state_dict(
        converted_model_name_or_path,
        block_index,
        config,
        use_auth_token=use_auth_token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    # dummy load, check that keys match
    report = block.load_state_dict(state_dict, strict=True)
    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

    for param_name, _ in block.named_parameters():
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        # Only floating-point tensors are cast; integer/bool buffers keep their dtype.
        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            param = param.to(torch_dtype)
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
    return block
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict(
    pretrained_model_name_or_path: str,
    block_index: int,
    config: BloomConfig,
    *,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    min_backoff: float = 5,
) -> OrderedDict[str, torch.Tensor]:
    """
    Fetch the state dict of one block, preferring the local cache and retrying downloads with exponential backoff.

    :param pretrained_model_name_or_path: repo that stores this block on branch BLOCK_BRANCH_PREFIX + block_index
    :param block_index: index of the transformer block to fetch
    :param config: model config, used to estimate the on-disk size of a block
    :param use_auth_token: optional HF Hub token for private repos
    :param cache_dir: local cache directory (required)
    :param max_disk_space: disk-space budget enforced before downloading
    :param min_backoff: initial retry delay in seconds; doubles on every failed attempt
    """
    revision = BLOCK_BRANCH_PREFIX + str(block_index)

    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            archive_file = get_file_from_repo(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if archive_file is not None:
                return torch.load(archive_file, map_location="cpu")
    except Exception:
        # Best-effort: a corrupted/missing cache entry just means we fall through to a fresh download.
        logger.debug(
            f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
        )

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    for attempt_no in itertools.count():
        try:
            with allow_cache_writes(cache_dir):
                block_size = get_block_size(config, "disk")
                free_disk_space_for(
                    pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
                )

                archive_file = get_file_from_repo(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                return torch.load(archive_file, map_location="cpu")
        except Exception:
            # Fix: dropped the unused `as e` binding. Retries forever with exponential backoff
            # (5 s, 10 s, 20 s, ...) — servers are expected to keep trying until the Hub recovers.
            delay = min_backoff * (2**attempt_no)
            logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)
|
|
||||||
|
|
||||||
|
|
||||||
# Mapping from user-facing dtype names to torch dtypes; "auto" is passed through unchanged.
DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
    "auto": "auto",
}
|
|
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"apply_residual_connection_post_layernorm": false,
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"attention_softmax_in_fp32": true,
|
|
||||||
"bos_token_id": 1,
|
|
||||||
"eos_token_id": 2,
|
|
||||||
"hidden_dropout": 0.0,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"layer_norm_epsilon": 1e-05,
|
|
||||||
"masked_softmax_fusion": true,
|
|
||||||
"model_type": "bloom",
|
|
||||||
"n_embed": 14336,
|
|
||||||
"n_layer": 70,
|
|
||||||
"num_attention_heads": 112,
|
|
||||||
"pretraining_tp": 4,
|
|
||||||
"slow_but_exact": false,
|
|
||||||
"transformers_version": "4.20.0.dev0",
|
|
||||||
"use_cache": true,
|
|
||||||
"vocab_size": 250880
|
|
||||||
}
|
|
@ -1,96 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
import torch.backends.quantized
|
|
||||||
import torch.nn as nn
|
|
||||||
import transformers
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from huggingface_hub import HfApi, Repository
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers.models.bloom.modeling_bloom import BloomModel
|
|
||||||
|
|
||||||
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP
|
|
||||||
from petals.client import DistributedBloomConfig
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Convert a BLOOM checkpoint into the per-branch layout (one branch per block) and push it to the HF Hub."""
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
    parser.add_argument(
        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
    )
    parser.add_argument(
        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
    )
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
    args = parser.parse_args()

    # Converting the full bloom-176b needs a lot of host RAM — warn early rather than fail mid-conversion.
    free_ram_gb = psutil.virtual_memory().available / 2**30
    if args.model == "bigscience/bloom" and free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    # Refuse to clobber an existing non-empty output directory (or a non-directory path).
    if os.path.exists(args.output_path) and (
        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
    config = DistributedBloomConfig.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    # The DHT prefix identifies this converted model in the swarm; reuse the output repo name for it.
    config.dht_prefix = args.output_repo

    model = BloomModel.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    if args.resize_token_embeddings:
        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
        model.resize_token_embeddings(args.resize_token_embeddings)
        config.vocab_size = args.resize_token_embeddings

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    os.makedirs(args.output_path, exist_ok=True)

    api = HfApi(token=args.use_auth_token)
    api.create_repo(args.output_repo, repo_type="model", exist_ok=True)
    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
    repo.git_pull()

    transformer_blocks = model.h
    logger.info(
        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
    )
    # Each block is committed to its own branch (block_0, block_1, ...) so servers can fetch blocks independently.
    for i, block in enumerate(tqdm(transformer_blocks)):
        repo.git_checkout(args.client_branch, create_branch_ok=True)
        with repo.commit(
            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
        ):
            torch.save(block.state_dict(), "./pytorch_model.bin")

    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
    repo.git_checkout(args.client_branch, create_branch_ok=True)
    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
        # Strip the transformer blocks — the client branch keeps only embeddings, layernorms, tokenizer, and config.
        model.h = nn.ModuleList()
        model.save_pretrained(".")
        tokenizer.save_pretrained(".")
        config.save_pretrained(".")

    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")


if __name__ == "__main__":
    main()
|
|
@ -1,10 +1,4 @@
|
|||||||
from petals.client.inference_session import InferenceSession
|
from petals.client.inference_session import InferenceSession
|
||||||
from petals.client.remote_model import (
|
|
||||||
DistributedBloomConfig,
|
|
||||||
DistributedBloomForCausalLM,
|
|
||||||
DistributedBloomForSequenceClassification,
|
|
||||||
DistributedBloomModel,
|
|
||||||
)
|
|
||||||
from petals.client.remote_sequential import RemoteSequential
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
from petals.client.routing.sequence_manager import RemoteSequenceManager
|
from petals.client.routing.sequence_manager import RemoteSequenceManager
|
||||||
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
||||||
|
@ -0,0 +1,94 @@
|
|||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers import BloomPreTrainedModel, modeling_utils
|
||||||
|
|
||||||
|
from petals.utils.version import get_compatible_model_repo
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FromPretrainedMixin:
    """Mixin that overrides from_pretrained() with Petals-friendly defaults and repo-name normalization."""

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        # NOTE(review): get_compatible_model_repo presumably rewrites legacy repo names to ones
        # supported by this Petals version — confirm against petals.utils.version.
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        if torch_dtype is None:
            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
            torch_dtype = "auto"

        # Skip downloading shards that only contain keys this class declares as ignorable
        # (see patched_get_checkpoint_shard_files).
        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
            return super().from_pretrained(
                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
            )

    # Reuse transformers' docstring, amended to reflect the Petals-specific defaults above.
    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",
        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
    ).replace(
        "torch_dtype (`str` or `torch.dtype`, *optional*)",
        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
    )
|
||||||
|
|
||||||
|
|
||||||
|
class _ShardConfig(threading.local):
    """Per-thread shard-loading state. The class-level default makes `ignored_keys` readable
    from ANY thread; a bare `threading.local()` whose attribute is assigned at import time
    only defines it in the importing thread, so other threads would get AttributeError."""

    # Patterns of parameter names whose shards should not be downloaded; None disables filtering.
    ignored_keys: Optional[List[str]] = None


_shard_config = _ShardConfig()
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
    """Temporarily tell the shard loader to skip parameters whose names match any of `patterns`."""
    saved_patterns = _shard_config.ignored_keys
    _shard_config.ignored_keys = patterns
    try:
        yield
    finally:
        # Restore the previous value so nested uses compose correctly.
        _shard_config.ignored_keys = saved_patterns
|
||||||
|
|
||||||
|
|
||||||
|
def patched_get_checkpoint_shard_files(
    pretrained_model_name_or_path, index_filename, *args, **kwargs
) -> Tuple[List[str], dict]:
    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""

    # Filtering is active only inside an ignore_keys(...) context (per thread).
    should_ignore_keys = _shard_config.ignored_keys is not None
    tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
    with tempdir_ctx as tempdir:
        if should_ignore_keys:
            with open(index_filename) as f:
                index = json.load(f)
            n_original_shards = len(set(index["weight_map"].values()))

            # Drop every parameter matching an ignored pattern; shards that end up referenced
            # by no remaining parameter are never downloaded.
            index["weight_map"] = {
                param_name: filename
                for param_name, filename in index["weight_map"].items()
                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
            }
            n_loaded_shards = len(set(index["weight_map"].values()))
            logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

            # Replace the original index with a patched JSON, where ignored keys are removed
            index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
            with open(index_filename, "w") as f:
                json.dump(index, f)

        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)


# Monkey-patch transformers so every from_pretrained() call goes through the filtering loader above.
original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files
|
@ -0,0 +1,88 @@
|
|||||||
|
import dataclasses
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from petals.utils.misc import DUMMY
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class PTuneConfig:
    """Configuration options for p-tuning (prompt tuning) of a distributed model."""

    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
|
||||||
|
|
||||||
|
|
||||||
|
class PTuneMixin:
    """Mixin adding p-tuning / deep p-tuning support: trainable prompt embeddings prepended to the input."""

    # Prompt embeddings are created locally (below), so they are expected to be missing from checkpoints.
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        """Create prompt-embedding modules according to config.tuning_mode ("ptune" / "deep_ptune")."""
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            # Fix: the original read `self.tuning_mode`, which is never set on the model, so an
            # unsupported mode raised AttributeError instead of the intended NotImplementedError.
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")

    def set_requires_grad(self, value):
        """Toggle requires_grad on every parameter of this model (used to freeze embeddings/layernorms)."""
        for p in self.parameters():
            p.requires_grad = value

    def get_prompt(self, batch_size):
        """Return (prompts, intermediate_prompts) expanded to `batch_size`, cast to the embedding dtype.

        intermediate_prompts is DUMMY unless tuning_mode == "deep_ptune".
        """
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            # Split the flat embedding into one prompt per transformer layer.
            intermediate_prompts = intermediate_prompts.view(
                batch_size,
                self.pre_seq_len,
                self.config.num_hidden_layers,
                self.config.hidden_size
                # TODO: should be num_hidden_layers - 1
            )
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# Pristine implementation captured at import time, before accelerate has a chance to patch it.
_original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows to bypass the accelerate.init_empty_weights() context manager
    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors by empty tensors by itself
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """
    patched_register_parameter = nn.Module.register_parameter
    try:
        # Restore the original implementation so newly created parameters get real storage.
        nn.Module.register_parameter = _original_register_parameter
        yield
    finally:
        # Put back whatever implementation was active on entry (it may have been accelerate's patch).
        nn.Module.register_parameter = patched_register_parameter
|
@ -1,268 +0,0 @@
|
|||||||
from contextlib import contextmanager
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import hivemind
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
||||||
from transformers.models.bloom import (
|
|
||||||
BloomConfig,
|
|
||||||
BloomForCausalLM,
|
|
||||||
BloomForSequenceClassification,
|
|
||||||
BloomModel,
|
|
||||||
BloomPreTrainedModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
from petals.bloom.modeling_utils import LMHead
|
|
||||||
from petals.client.remote_generation import RemoteGenerationMixin
|
|
||||||
from petals.client.remote_sequential import RemoteSequential
|
|
||||||
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
|
||||||
from petals.constants import PUBLIC_INITIAL_PEERS
|
|
||||||
from petals.utils.misc import DUMMY
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
    """
    A bloom config that contains information about DHT peers.
    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
    """

    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers

    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]

    # These settings matter for running the client with dtype bfloat16 on CPU.
    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
    use_chunked_forward: Union[str, bool] = "auto"
    chunked_forward_step: int = 16384
|
|
||||||
|
|
||||||
|
|
||||||
# Pristine implementation captured at import time, before accelerate has a chance to patch it.
original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows to bypass the accelerate.init_empty_weights() context manager
    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors by empty tensors by itself
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """

    try:
        possibly_patched_register_parameter = nn.Module.register_parameter
        # Restore the original implementation so newly created parameters get real storage.
        nn.Module.register_parameter = original_register_parameter
        yield
    finally:
        # Put back whatever implementation was active on entry (it may have been accelerate's patch).
        nn.Module.register_parameter = possibly_patched_register_parameter
|
|
||||||
|
|
||||||
|
|
||||||
class _FromPretrainedDefaultsMixin:
    """Mixin overriding from_pretrained() defaults: low_cpu_mem_usage=True and torch_dtype="auto"."""

    @classmethod
    def from_pretrained(
        cls,
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        if torch_dtype is None:
            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
            torch_dtype = "auto"
        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)

    # Reuse transformers' docstring, amended to reflect the Petals-specific defaults above.
    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",
        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
    ).replace(
        "torch_dtype (`str` or `torch.dtype`, *optional*)",
        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
    )
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
|
|
||||||
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
|
||||||
|
|
||||||
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
|
|
||||||
r"^(intermediate_)?prompt_embeddings\.weight$",
|
|
||||||
]
|
|
||||||
|
|
||||||
config_class = DistributedBloomConfig
|
|
||||||
|
|
||||||
def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
|
|
||||||
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
|
|
||||||
assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
|
|
||||||
|
|
||||||
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
|
|
||||||
super().__init__(config)
|
|
||||||
assert len(self.h) == 0
|
|
||||||
config.n_layer = n_layer
|
|
||||||
|
|
||||||
self.h = RemoteSequential(config, dht=dht)
|
|
||||||
|
|
||||||
# Forbid accumulate grads for embeddings and layernorm
|
|
||||||
self.set_requires_grad(False)
|
|
||||||
|
|
||||||
if config.tuning_mode and "ptune" in config.tuning_mode:
|
|
||||||
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
|
||||||
self.pre_seq_len = config.pre_seq_len
|
|
||||||
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
|
||||||
|
|
||||||
with force_non_empty_weights():
|
|
||||||
if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
|
|
||||||
logger.info(
|
|
||||||
"Prompt embeddings and their optimizer statistics will be kept in float32 "
|
|
||||||
"to increase ptune quality"
|
|
||||||
)
|
|
||||||
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
|
|
||||||
if config.tuning_mode == "deep_ptune":
|
|
||||||
self.intermediate_prompt_embeddings = nn.Embedding(
|
|
||||||
self.pre_seq_len,
|
|
||||||
config.num_hidden_layers * config.hidden_size,
|
|
||||||
# ^-- TODO: should be num_hidden_layers - 1
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
elif config.tuning_mode:
|
|
||||||
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
|
||||||
|
|
||||||
def set_requires_grad(self, value):
    """Toggle gradient accumulation for every parameter of this module at once."""
    for parameter in self.parameters():
        parameter.requires_grad = value
|
|
||||||
|
|
||||||
def get_prompt(self, batch_size):
    """Return trainable prompts for a batch: (input-level prompts, per-block intermediate prompts).

    Intermediate prompts are DUMMY unless deep p-tuning is enabled.
    """
    indices = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
    indices = indices.to(self.word_embeddings.weight.device)
    prompts = self.prompt_embeddings(indices)

    if self.config.tuning_mode != "deep_ptune":
        intermediate_prompts = DUMMY  # placeholder tensor: no per-block prompts in shallow p-tuning
    else:
        intermediate_prompts = self.intermediate_prompt_embeddings(indices)
        intermediate_prompts = intermediate_prompts.view(
            batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
        )
        # -> (num_blocks, batch_size, pre_seq_len, hidden_size)
        intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])

    target_dtype = self.word_embeddings.weight.dtype
    return prompts.to(target_dtype), intermediate_prompts.to(target_dtype)
|
|
||||||
|
|
||||||
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Embed inputs locally, run transformer blocks remotely via `self.h`, then apply the final layernorm.

    Exactly one of `input_ids` / `inputs_embeds` must be given; `attention_mask` must be None.
    Extra HF-API kwargs (e.g. `use_cache`) are ignored with a debug message.
    """
    assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"

    for k, v in kwargs.items():
        if not (v is None or v is False):
            logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)

    if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
        # Prepend trainable prompt embeddings; intermediate prompts are injected into remote blocks below.
        batch_size = inputs_embeds.shape[0]
        prompts, intermediate_prompts = self.get_prompt(batch_size)
        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

    hidden_states = self.word_embeddings_layernorm(inputs_embeds)
    output_shape = input_shape + (hidden_states.size(-1),)

    if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
        hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
    else:
        hidden_states = self.h(hidden_states)

    # Remove prefix
    if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
        hidden_states = hidden_states[:, self.pre_seq_len :]

    # Add last hidden state
    hidden_states = self.ln_f(hidden_states)
    hidden_states = hidden_states.view(output_shape)
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=None,
        hidden_states=None,
        attentions=None,
    )
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = (
        BloomForCausalLM._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
        # Bug fix: escape the dot so the pattern matches only the literal attribute path.
        + [r"^lm_head\.word_embeddings\.weight$"]  # Missing since they are shared with input embeddings
    )

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        # LM head shares weights with the input embeddings (tied weights).
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.word_embeddings

    def get_output_embeddings(self):
        # Return None for tied embeddings so HF's post_init does not attempt to re-tie them.
        if self.config.tie_word_embeddings:
            return None
        return self.lm_head

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        assert isinstance(new_embeddings, nn.Embedding)
        # Keep the LM head tied to the (new) input embeddings.
        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings

    def set_output_embeddings(self, new_lm_head: nn.Linear):
        """Copy weights (and bias, if present) from `new_lm_head` into the shared LM head in-place."""
        with torch.no_grad():
            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
            # Bug fix: the original unconditionally assigned `self.lm_head.bias[...] = new_lm_head.bias`,
            # which raises TypeError when either bias is None (e.g. nn.Linear(..., bias=False)).
            if self.lm_head.bias is not None:
                assert new_lm_head.bias is not None, "Cannot copy a bias-less lm_head into a head with bias"
                self.lm_head.bias[...] = new_lm_head.bias
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
    """BloomForSequenceClassification whose transformer layers are hosted by the swarm."""

    config_class = DistributedBloomConfig

    _keys_to_ignore_on_load_missing = (
        BloomForSequenceClassification._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
    )

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        # Remote backbone plus a small local classification head, cast to the configured dtype.
        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)

        # Initialize weights and apply final processing
        self.post_init()
|
|
@ -0,0 +1,2 @@
|
|||||||
|
from petals.models.bloom import *
|
||||||
|
from petals.models.llama import *
|
@ -0,0 +1,7 @@
|
|||||||
|
from petals.models.bloom.block import WrappedBloomBlock
|
||||||
|
from petals.models.bloom.config import DistributedBloomConfig
|
||||||
|
from petals.models.bloom.model import (
|
||||||
|
DistributedBloomForCausalLM,
|
||||||
|
DistributedBloomForSequenceClassification,
|
||||||
|
DistributedBloomModel,
|
||||||
|
)
|
@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
Bloom intermediate layer
|
||||||
|
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
||||||
|
See commit history for authorship.
|
||||||
|
"""
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedBloomBlock(BloomBlock):
    """A BloomBlock that builds its own causal attention mask and ALiBi tensor, as Petals servers require."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        assert attention_mask is None, "Non-causal attention masks are not supported yet"
        batch_size, seq_length = hidden_states.shape[:2]
        num_past_tokens = layer_past[0].shape[-1] if layer_past is not None else 0
        total_length = seq_length + num_past_tokens

        # All positions are attendable; causality is enforced by _prepare_attn_mask below.
        full_mask = torch.ones((batch_size, total_length), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(full_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        # _prepare_attn_mask does not touch `self`, so it is safe to call unbound with None.
        causal_mask = BloomModel._prepare_attn_mask(None, full_mask, (batch_size, seq_length), num_past_tokens)
        return super().forward(
            hidden_states, *args, attention_mask=causal_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )
|
@ -0,0 +1,35 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers.models.bloom import BloomConfig
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||||
|
|
||||||
|
from petals.client.lm_head import LMHeadConfig
|
||||||
|
from petals.client.ptune import PTuneConfig
|
||||||
|
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
||||||
|
from petals.models.bloom.block import WrappedBloomBlock
|
||||||
|
from petals.utils.auto_config import AutoDistributedConfig
|
||||||
|
from petals.utils.version import get_compatible_model_repo
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
    """BloomConfig extended with Petals client/server settings (routing, p-tuning, CPU-efficient LM head)."""

    block_class = WrappedBloomBlock
    attn_class = BloomAttention
    block_prefix = "h"  # transformer block weights are named "h.<index>.*"

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        """Load the config, deriving a default DHT prefix from the repo name when loading from the Hub."""
        if dht_prefix is None and model_name_or_path is not None and not os.path.isdir(model_name_or_path):
            # We need "-petals" for backward compatibility with Petals < 1.2.0
            dht_prefix = str(model_name_or_path) + "-petals"
            logger.info(f"Using DHT prefix: {dht_prefix}")
        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Let AutoDistributedConfig.from_pretrained recognize BLOOM checkpoints via this config class.
AutoDistributedConfig.register(DistributedBloomConfig)
|
@ -0,0 +1,134 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
|
||||||
|
|
||||||
|
from petals.client.from_pretrained import FromPretrainedMixin
|
||||||
|
from petals.client.lm_head import LMHead
|
||||||
|
from petals.client.ptune import PTuneMixin
|
||||||
|
from petals.client.remote_generation import RemoteGenerationMixin
|
||||||
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
|
from petals.models.bloom.config import DistributedBloomConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = (
        BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing
    )
    # Block weights ("h.*") live on remote servers, so local checkpoints are expected to lack them.
    _keys_to_ignore_on_load_unexpected = [r"^h\."]

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.num_hidden_layers = n_layer

        # Replace the (empty) local layer stack with a sequence of remote blocks.
        self.h = RemoteSequential(config, dht=dht)

        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Embed inputs locally, run transformer blocks remotely via `self.h`, then apply the final layernorm.

        Exactly one of `input_ids` / `inputs_embeds` must be given; `attention_mask` must be None.
        Extra HF-API kwargs (e.g. `use_cache`) are ignored with a debug message.
        """
        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            # Prepend trainable prompt embeddings; intermediate prompts are injected into remote blocks below.
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.h(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
    """Causal-LM head on top of the distributed BLOOM backbone."""

    config_class = DistributedBloomConfig

    _keys_to_ignore_on_load_missing = (
        BloomForCausalLM._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
        + [r"^lm_head\."]  # Missing since they are shared with input embeddings
    )
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
    """Sequence-classification head on top of the distributed BLOOM backbone."""

    config_class = DistributedBloomConfig

    _keys_to_ignore_on_load_missing = (
        BloomForSequenceClassification._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
    )
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        # Remote backbone plus a small local classification head, cast to the configured dtype.
        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)

        # Initialize weights and apply final processing
        self.post_init()
|
@ -0,0 +1,7 @@
|
|||||||
|
from petals.models.llama.block import WrappedLlamaBlock
|
||||||
|
from petals.models.llama.config import DistributedLlamaConfig
|
||||||
|
from petals.models.llama.model import (
|
||||||
|
DistributedLlamaForCausalLM,
|
||||||
|
DistributedLlamaForSequenceClassification,
|
||||||
|
DistributedLlamaModel,
|
||||||
|
)
|
@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
LLaMA intermediate layer
|
||||||
|
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
|
See commit history for authorship.
|
||||||
|
"""
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedLlamaBlock(LlamaDecoderLayer):
    """A LlamaDecoderLayer that accepts BLOOM-style KV caches and builds its own masks/position ids.

    Petals servers store past key/values in the BLOOM layout, so this wrapper converts the cache
    to/from the LLaMA layout around each call to the underlying decoder layer.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            # Incoming cache is in BLOOM layout; convert to LLaMA layout before the attention call.
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

        if position_ids is None:
            # Positions continue after the cached prefix.
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        # _prepare_decoder_attention_mask does not touch `self`, so it is called unbound with None.
        attention_mask = LlamaModel._prepare_decoder_attention_mask(
            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            # Convert the returned cache back to BLOOM layout for storage on the server.
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_llama_to_bloom(
                present_key_value, batch_size, seq_length_with_past
            )
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_llama(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        """Convert (key, value) from BLOOM layout to LLaMA layout (batch, heads, seq, head_dim).

        The permute followed by .view() is only legal because the view merely splits dim 0;
        the permuted trailing dims keep their strides — do not "simplify" this into a reshape
        chain without checking contiguity.
        """
        key_states, value_states = key_value
        # Key arrives as (batch * heads, head_dim, seq) — see the inverse method below.
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_from_llama_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        """Convert (key, value) from LLaMA layout back to BLOOM layout.

        Result: value is (batch * heads, seq, head_dim); key is (batch * heads, head_dim, seq).
        """
        key_states, value_states = key_value
        value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
        key_states = key_states.view(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)
|
@ -0,0 +1,35 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers.models.llama import LlamaConfig
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
|
|
||||||
|
from petals.client.lm_head import LMHeadConfig
|
||||||
|
from petals.client.ptune import PTuneConfig
|
||||||
|
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
||||||
|
from petals.models.llama.block import WrappedLlamaBlock
|
||||||
|
from petals.utils.auto_config import AutoDistributedConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
    """LlamaConfig extended with Petals client/server settings (routing, p-tuning, CPU-efficient LM head)."""

    block_class = WrappedLlamaBlock
    attn_class = LlamaAttention
    block_prefix = "model.layers"  # transformer block weights are named "model.layers.<index>.*"

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        """Load the config, deriving a default DHT prefix from the repo name (account stripped)."""
        if dht_prefix is None and model_name_or_path is not None and not os.path.isdir(model_name_or_path):
            dht_prefix = str(model_name_or_path)
            if "/" in dht_prefix:  # If present, strip repository name to merge blocks hosted by different accounts
                dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
            logger.info(f"Using DHT prefix: {dht_prefix}")
        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Let AutoDistributedConfig.from_pretrained recognize LLaMA checkpoints via this config class.
AutoDistributedConfig.register(DistributedLlamaConfig)
|
@ -0,0 +1,152 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
|
||||||
|
|
||||||
|
from petals.client.from_pretrained import FromPretrainedMixin
|
||||||
|
from petals.client.lm_head import LMHead
|
||||||
|
from petals.client.ptune import PTuneMixin
|
||||||
|
from petals.client.remote_generation import RemoteGenerationMixin
|
||||||
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
|
from petals.models.llama.config import DistributedLlamaConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
    """LlamaModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    # Decoder-layer weights live on remote servers, so local checkpoints are expected to lack them.
    _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."]

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.layers) == 0
        config.num_hidden_layers = n_layer

        # Replace the (empty) local layer stack with a sequence of remote blocks.
        self.layers = RemoteSequential(config, dht=dht)

        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Embed inputs locally, run decoder layers remotely via `self.layers`, then apply the final RMSNorm.

        Exactly one of `input_ids` / `inputs_embeds` must be given; `attention_mask` must be None.
        Extra HF-API kwargs (e.g. `use_cache`) are ignored with a debug message.
        """
        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            # Prepend trainable prompt embeddings; intermediate prompts are injected into remote blocks below.
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        # Unlike BLOOM, LLaMA applies no layernorm before the first decoder layer.
        hidden_states = inputs_embeds
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.layers(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers

    @property
    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return self.norm
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
    """Causal-LM head on top of the distributed LLaMA backbone."""

    config_class = DistributedLlamaConfig

    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    def __init__(self, config: DistributedLlamaConfig):
        LlamaPreTrainedModel.__init__(self, config)
        self.model = DistributedLlamaModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
    """Sequence-classification head on top of the distributed LLaMA backbone."""

    config_class = DistributedLlamaConfig

    _keys_to_ignore_on_load_missing = (
        LlamaForSequenceClassification._keys_to_ignore_on_load_missing
        + DistributedLlamaModel._keys_to_ignore_on_load_missing
    )
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    def __init__(self, config):
        LlamaPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        # Remote backbone plus a small local classification head.
        self.model = DistributedLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model
|
@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
|
||||||
|
If necessary, one can rewrite this to implement a different behavior, such as:
|
||||||
|
- loading files from a local data source (e.g. S3)
|
||||||
|
- load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
|
||||||
|
- fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
|
||||||
|
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from huggingface_hub import get_hf_file_metadata, hf_hub_url
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers.utils import get_file_from_repo
|
||||||
|
|
||||||
|
from petals.server.block_utils import resolve_block_dtype
|
||||||
|
from petals.utils.auto_config import AutoDistributedConfig
|
||||||
|
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_pretrained_block(
    model_name: str,
    block_index: int,
    *,
    config: Optional[PretrainedConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> nn.Module:
    """Download one transformer block's weights from the Hub and materialize it on CPU.

    Args:
        model_name: Hub repo or local path of the full model.
        block_index: index of the block within the model (names prefixed with "<block_prefix>.<index>.").
        config: optional pre-loaded distributed config; fetched automatically when None.
        torch_dtype: dtype to cast float weights to ("auto" resolves from the config).

    Returns:
        The block as an ``nn.Module`` with all weights loaded on CPU.
    """
    if config is None:
        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    # NOTE(review): DTYPE_MAP is not among the imports visible in this chunk —
    # presumably imported elsewhere in this module; verify.
    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
    torch_dtype = resolve_block_dtype(config, torch_dtype)

    # Build an empty (meta-device) shell so no memory is allocated before real weights arrive.
    with init_empty_weights():
        block = config.block_class(config)

    block_prefix = f"{config.block_prefix}.{block_index}."
    state_dict = _load_state_dict_from_repo(
        model_name,
        block_prefix,
        revision=revision,
        use_auth_token=use_auth_token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    # dummy load, check that keys match
    report = block.load_state_dict(state_dict, strict=True)
    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

    for param_name, _ in block.named_parameters():
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        # Keep integer/bool tensors (e.g. quantization metadata) in their original dtype.
        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            param = param.to(torch_dtype)
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

    logger.info(f"Loaded {model_name} block {block_index}, {report}")
    return block
|
||||||
|
|
||||||
|
|
||||||
|
StateDict = Dict[str, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict_from_repo(
    model_name: str,
    block_prefix: str,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
) -> StateDict:
    """Load the weights of one transformer block from a Hugging Face repo.

    Downloads only the shards that contain parameters starting with `block_prefix`
    (e.g. "transformer.h.0."), so servers never fetch the whole model.

    :param model_name: repo id, e.g. "bigscience/bloom"
    :param block_prefix: parameter-name prefix selecting the block, including the trailing dot
    :param revision: repo revision (branch, tag, or commit hash)
    :param use_auth_token: HF Hub token for private/gated repos
    :param cache_dir: local cache directory for downloaded files
    :param max_disk_space: cache size limit in bytes (old files are evicted to fit)
    :returns: state dict with `block_prefix` stripped from every key
    :raises RuntimeError: if a sharded repo's index contains no parameters for the block
    """
    # Pass `revision` here as well — otherwise the index could come from a different
    # revision than the shards fetched below, yielding a mismatched weight map.
    index_file = get_file_from_repo(
        model_name,
        filename="pytorch_model.bin.index.json",
        revision=revision,
        use_auth_token=use_auth_token,
        cache_dir=cache_dir,
    )
    if index_file is not None:  # Sharded model
        with open(index_file) as f:
            index = json.load(f)
        filenames = {
            filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
        }
        if not filenames:
            raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
    else:  # Non-sharded model
        filenames = {"pytorch_model.bin"}
    logger.debug(f"Loading {block_prefix}* from {filenames}")

    state_dict = {}
    for filename in filenames:
        shard_state_dict = _load_state_dict_from_file(
            model_name,
            filename,
            revision=revision,
            use_auth_token=use_auth_token,
            cache_dir=cache_dir,
            max_disk_space=max_disk_space,
        )
        shard_state_dict = {
            param_name[len(block_prefix) :]: param
            for param_name, param in shard_state_dict.items()
            if param_name.startswith(block_prefix)
        }  # Remove unused parameters from memory
        state_dict.update(shard_state_dict)
    return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict_from_file(
    model_name: str,
    filename: str,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
) -> StateDict:
    """Load one checkpoint shard, using the local cache when possible.

    First checks the local cache; on a miss (or a corrupted cache entry), frees disk
    space if needed and downloads the file, retrying forever with `delay` seconds
    between attempts (servers should keep trying through transient Hub outages).

    :param model_name: repo id, e.g. "bigscience/bloom"
    :param filename: shard file name inside the repo, e.g. "pytorch_model_00001-of-00072.bin"
    :param delay: seconds to sleep between download retries
    :returns: the shard's state dict, loaded onto CPU
    """
    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            path = get_file_from_repo(
                model_name,
                filename,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if path is not None:
                return torch.load(path, map_location="cpu")
    except Exception:
        logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    while True:
        try:
            with allow_cache_writes(cache_dir):
                url = hf_hub_url(model_name, filename, revision=revision)
                file_size = get_hf_file_metadata(url, token=use_auth_token).size
                if file_size is not None:
                    free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")

                path = get_file_from_repo(
                    model_name,
                    filename,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                if path is None:
                    raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
                return torch.load(path, map_location="cpu")
        except Exception:
            logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)
|
||||||
|
|
||||||
|
|
||||||
|
# User-facing dtype names -> torch dtypes; "auto" is resolved later (presumably via resolve_block_dtype).
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
@ -0,0 +1 @@
|
|||||||
|
from petals.utils.auto_config import AutoDistributedConfig
|
@ -0,0 +1,23 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from transformers import AutoConfig, PretrainedConfig
|
||||||
|
|
||||||
|
CONFIG_MAPPING = {}  # Maps config.model_type -> distributed config class; populated with AutoDistributedConfig.register()
|
||||||
|
|
||||||
|
|
||||||
|
class AutoDistributedConfig:
    """AutoConfig-style dispatcher that returns Petals' distributed config classes.

    Model-specific config classes register themselves via `register()`, and
    `from_pretrained()` dispatches on the repo's `model_type` field.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
        """Load the distributed config for a model repo.

        :raises ValueError: if no distributed config class is registered for the model type
        """
        config = AutoConfig.from_pretrained(*args, **kwargs)
        if config.model_type not in CONFIG_MAPPING:
            raise ValueError(f"Petals does not support model type {config.model_type}")

        dist_config_class = CONFIG_MAPPING[config.model_type]
        return dist_config_class.from_pretrained(*args, **kwargs)

    @staticmethod
    def register(config_class: Type[PretrainedConfig]) -> None:
        """Register `config_class` under its `model_type`.

        Uses explicit exceptions instead of `assert` so validation survives `python -O`.

        :raises TypeError: if `config_class` is not a PretrainedConfig subclass
        :raises ValueError: if the model type is already registered
        """
        if not issubclass(config_class, PretrainedConfig):
            raise TypeError(f"Expected a PretrainedConfig subclass, got {config_class}")
        if config_class.model_type in CONFIG_MAPPING:
            raise ValueError(f"Model type {config_class.model_type} is already registered")

        CONFIG_MAPPING[config_class.model_type] = config_class
|
@ -1,17 +1,16 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from petals.bloom.from_pretrained import load_pretrained_block
|
|
||||||
from petals.client import DistributedBloomConfig
|
|
||||||
from petals.server.block_utils import resolve_block_dtype
|
from petals.server.block_utils import resolve_block_dtype
|
||||||
|
from petals.server.from_pretrained import load_pretrained_block
|
||||||
|
from petals.utils.auto_config import AutoDistributedConfig
|
||||||
from test_utils import MODEL_NAME
|
from test_utils import MODEL_NAME
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.forked
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
def test_block_dtype(torch_dtype):
    """Every parameter of a loaded block must carry the dtype the server resolves for it."""
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)
    expected_dtype = resolve_block_dtype(config, torch_dtype)
    for param in block.parameters():
        assert param.dtype == expected_dtype
Loading…
Reference in New Issue