Add LLaMA support (#323)
This PR:

1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
   - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
   - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accommodate blocks from different repos (there are a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers.

Upgrade instructions:

- Remove disk caches for blocks in the old (converted) format to save disk space. That is, remove the `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
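For reference, a minimal usage sketch of the new classes (not part of the diff; the repo name is a placeholder, and it assumes a swarm serving that model is reachable). The imports follow the new `petals.models.llama` and `petals.utils.auto_config` modules introduced below:

```python
import torch
from transformers import AutoTokenizer

from petals.models.llama import DistributedLlamaForCausalLM
from petals.utils.auto_config import AutoDistributedConfig  # picks DistributedLlamaConfig or DistributedBloomConfig

model_name = "username/llama-65b-hf"  # placeholder repo; blocks are announced under the DHT prefix "llama-65b-hf"
config = AutoDistributedConfig.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))
```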
parent 5c0733711a
commit cb3f018f9f
@ -1,62 +0,0 @@
"""
Bloom intermediate layer
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import os
from typing import Optional, Tuple

import torch.nn.quantized.dynamic.modules.linear
import transformers
from packaging import version
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor

if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    assert (
        version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
    ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"


class WrappedBloomBlock(BloomBlock):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        assert attention_mask is None
        batch_size, seq_length = hidden_states.shape[:2]
        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
        seq_length_with_past = seq_length + past_length
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask
@ -1,132 +0,0 @@
|
||||
"""
|
||||
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
|
||||
If necessary, one can rewrite this to implement a different behavior, such as:
|
||||
- loading files from a local data source (e.g. S3)
|
||||
- load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
|
||||
- fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import time
|
||||
from typing import Optional, OrderedDict, Union
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers.modeling_utils import WEIGHTS_NAME
|
||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
||||
from transformers.utils import get_file_from_repo
|
||||
|
||||
from petals.bloom.block import WrappedBloomBlock
|
||||
from petals.server.block_utils import get_block_size, resolve_block_dtype
|
||||
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
CLIENT_BRANCH = "main"
|
||||
BLOCK_BRANCH_PREFIX = "block_"
|
||||
|
||||
|
||||
def load_pretrained_block(
|
||||
converted_model_name_or_path: str,
|
||||
block_index: int,
|
||||
config: Optional[BloomConfig] = None,
|
||||
torch_dtype: Union[torch.dtype, str] = "auto",
|
||||
use_auth_token: Optional[str] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
max_disk_space: Optional[int] = None,
|
||||
) -> WrappedBloomBlock:
|
||||
"""Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
|
||||
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
||||
torch_dtype = resolve_block_dtype(config, torch_dtype)
|
||||
|
||||
if config is None:
|
||||
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
|
||||
if cache_dir is None:
|
||||
cache_dir = DEFAULT_CACHE_DIR
|
||||
|
||||
with init_empty_weights():
|
||||
block = WrappedBloomBlock(config)
|
||||
|
||||
state_dict = _load_state_dict(
|
||||
converted_model_name_or_path,
|
||||
block_index,
|
||||
config,
|
||||
use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir,
|
||||
max_disk_space=max_disk_space,
|
||||
)
|
||||
|
||||
# dummy load, check that keys match
|
||||
report = block.load_state_dict(state_dict, strict=True)
|
||||
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
|
||||
|
||||
for param_name, _ in block.named_parameters():
|
||||
assert param_name in state_dict, f"{param_name} not in state dict"
|
||||
param = state_dict[param_name]
|
||||
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
param = param.to(torch_dtype)
|
||||
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
|
||||
|
||||
logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
|
||||
return block
|
||||
|
||||
|
||||
def _load_state_dict(
|
||||
pretrained_model_name_or_path: str,
|
||||
block_index: int,
|
||||
config: BloomConfig,
|
||||
*,
|
||||
use_auth_token: Optional[str] = None,
|
||||
cache_dir: str,
|
||||
max_disk_space: Optional[int] = None,
|
||||
min_backoff: float = 5,
|
||||
) -> OrderedDict[str, torch.Tensor]:
|
||||
revision = BLOCK_BRANCH_PREFIX + str(block_index)
|
||||
|
||||
# First, try to find the weights locally
|
||||
try:
|
||||
with allow_cache_reads(cache_dir):
|
||||
archive_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path,
|
||||
filename=WEIGHTS_NAME,
|
||||
revision=revision,
|
||||
use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=True,
|
||||
)
|
||||
if archive_file is not None:
|
||||
return torch.load(archive_file, map_location="cpu")
|
||||
except Exception:
|
||||
logger.debug(
|
||||
f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
|
||||
)
|
||||
|
||||
# If not found, ensure that we have enough disk space to download them (maybe remove something)
|
||||
for attempt_no in itertools.count():
|
||||
try:
|
||||
with allow_cache_writes(cache_dir):
|
||||
block_size = get_block_size(config, "disk")
|
||||
free_disk_space_for(
|
||||
pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
|
||||
)
|
||||
|
||||
archive_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path,
|
||||
filename=WEIGHTS_NAME,
|
||||
revision=revision,
|
||||
use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=False,
|
||||
)
|
||||
return torch.load(archive_file, map_location="cpu")
|
||||
except Exception as e:
|
||||
delay = min_backoff * (2**attempt_no)
|
||||
logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
@ -1,20 +0,0 @@
{
  "apply_residual_connection_post_layernorm": false,
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_embed": 14336,
  "n_layer": 70,
  "num_attention_heads": 112,
  "pretraining_tp": 4,
  "slow_but_exact": false,
  "transformers_version": "4.20.0.dev0",
  "use_cache": true,
  "vocab_size": 250880
}
@ -1,96 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import psutil
|
||||
import torch.backends.quantized
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from hivemind.utils.logging import get_logger
|
||||
from huggingface_hub import HfApi, Repository
|
||||
from tqdm.auto import tqdm
|
||||
from transformers.models.bloom.modeling_bloom import BloomModel
|
||||
|
||||
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP
|
||||
from petals.client import DistributedBloomConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
|
||||
|
||||
parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
|
||||
parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
|
||||
parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
|
||||
parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
|
||||
parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
|
||||
parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
|
||||
parser.add_argument(
|
||||
"--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
|
||||
)
|
||||
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
||||
parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
|
||||
args = parser.parse_args()
|
||||
|
||||
free_ram_gb = psutil.virtual_memory().available / 2**30
|
||||
if args.model == "bigscience/bloom" and free_ram_gb < 400:
|
||||
logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
|
||||
|
||||
assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
|
||||
if os.path.exists(args.output_path) and (
|
||||
len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
|
||||
):
|
||||
raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
|
||||
|
||||
logger.info(f"Loading source model {args.model} (this may take a few minutes)")
|
||||
config = DistributedBloomConfig.from_pretrained(
|
||||
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
||||
)
|
||||
config.dht_prefix = args.output_repo
|
||||
|
||||
model = BloomModel.from_pretrained(
|
||||
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
||||
)
|
||||
if args.resize_token_embeddings:
|
||||
logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
|
||||
model.resize_token_embeddings(args.resize_token_embeddings)
|
||||
config.vocab_size = args.resize_token_embeddings
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
||||
)
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
|
||||
api = HfApi(token=args.use_auth_token)
|
||||
api.create_repo(args.output_repo, repo_type="model", exist_ok=True)
|
||||
repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
|
||||
repo.git_pull()
|
||||
|
||||
transformer_blocks = model.h
|
||||
logger.info(
|
||||
f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
|
||||
f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
|
||||
)
|
||||
for i, block in enumerate(tqdm(transformer_blocks)):
|
||||
repo.git_checkout(args.client_branch, create_branch_ok=True)
|
||||
with repo.commit(
|
||||
commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
|
||||
):
|
||||
torch.save(block.state_dict(), "./pytorch_model.bin")
|
||||
|
||||
logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
|
||||
repo.git_checkout(args.client_branch, create_branch_ok=True)
|
||||
with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
|
||||
model.h = nn.ModuleList()
|
||||
model.save_pretrained(".")
|
||||
tokenizer.save_pretrained(".")
|
||||
config.save_pretrained(".")
|
||||
|
||||
logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,10 +1,4 @@
from petals.client.inference_session import InferenceSession
from petals.client.remote_model import (
    DistributedBloomConfig,
    DistributedBloomForCausalLM,
    DistributedBloomForSequenceClassification,
    DistributedBloomModel,
)
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
@ -0,0 +1,94 @@
import contextlib
import json
import os
import re
import tempfile
import threading
from typing import List, Optional, Tuple, Union

import torch
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils

from petals.utils.version import get_compatible_model_repo

logger = get_logger(__name__)


class FromPretrainedMixin:
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        if torch_dtype is None:
            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
            torch_dtype = "auto"

        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
            return super().from_pretrained(
                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
            )

    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",
        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
    ).replace(
        "torch_dtype (`str` or `torch.dtype`, *optional*)",
        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
    )


_shard_config = threading.local()
_shard_config.ignored_keys = None


@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
    try:
        prev_patterns = _shard_config.ignored_keys
        _shard_config.ignored_keys = patterns
        yield
    finally:
        _shard_config.ignored_keys = prev_patterns


def patched_get_checkpoint_shard_files(
    pretrained_model_name_or_path, index_filename, *args, **kwargs
) -> Tuple[List[str], dict]:
    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""

    should_ignore_keys = _shard_config.ignored_keys is not None
    tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
    with tempdir_ctx as tempdir:
        if should_ignore_keys:
            with open(index_filename) as f:
                index = json.load(f)
            n_original_shards = len(set(index["weight_map"].values()))

            index["weight_map"] = {
                param_name: filename
                for param_name, filename in index["weight_map"].items()
                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
            }
            n_loaded_shards = len(set(index["weight_map"].values()))
            logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

            # Replace the original index with a patched JSON, where ignored keys are removed
            index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
            with open(index_filename, "w") as f:
                json.dump(index, f)

        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)


original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files
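Side note (not part of the diff): with `ignored_keys` set, the patched loader rewrites the checkpoint index so that shards containing only ignored parameters are never downloaded. A self-contained sketch of that filtering step, with a hypothetical `weight_map` (file names and keys are made up for illustration):

```python
import re

# Hypothetical weight_map of a sharded checkpoint
weight_map = {
    "word_embeddings.weight": "pytorch_model-00001-of-00003.bin",
    "h.0.self_attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
    "ln_f.weight": "pytorch_model-00003-of-00003.bin",
}
ignored_patterns = [r"^h\."]  # what a client model passes via _keys_to_ignore_on_load_unexpected

kept = {
    name: filename
    for name, filename in weight_map.items()
    if all(re.search(pattern, name) is None for pattern in ignored_patterns)
}
print(sorted(set(kept.values())))  # only 2 of the 3 shards would be fetched by the client
```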
@ -0,0 +1,88 @@
import dataclasses
from contextlib import contextmanager
from typing import Optional

import torch
import torch.nn as nn
from hivemind import get_logger
from transformers import PretrainedConfig

from petals.utils.misc import DUMMY

logger = get_logger(__name__)


@dataclasses.dataclass
class PTuneConfig:
    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]


class PTuneMixin:
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")

    def set_requires_grad(self, value):
        for p in self.parameters():
            p.requires_grad = value

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size,
                self.pre_seq_len,
                self.config.num_hidden_layers,
                self.config.hidden_size
                # TODO: should be num_hidden_layers - 1
            )
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)


_original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows to bypass the accelerate.init_empty_weights() context manager
    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors by empty tensors by itself
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """

    try:
        possibly_patched_register_parameter = nn.Module.register_parameter
        nn.Module.register_parameter = _original_register_parameter
        yield
    finally:
        nn.Module.register_parameter = possibly_patched_register_parameter
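Side note (not part of the diff): a quick sanity check of the shapes produced by `get_prompt()` for `deep_ptune`, as a standalone sketch with made-up sizes. The intermediate prompts are reshaped so that the first axis enumerates remote layers:

```python
import torch

batch_size, pre_seq_len, hidden_size, num_hidden_layers = 2, 4, 16, 3

prompt_embeddings = torch.nn.Embedding(pre_seq_len, hidden_size)
intermediate_prompt_embeddings = torch.nn.Embedding(pre_seq_len, num_hidden_layers * hidden_size)

prefix_tokens = torch.arange(pre_seq_len).unsqueeze(0).expand(batch_size, -1)
prompts = prompt_embeddings(prefix_tokens)  # [batch_size, pre_seq_len, hidden_size]

intermediate = intermediate_prompt_embeddings(prefix_tokens)
intermediate = intermediate.view(batch_size, pre_seq_len, num_hidden_layers, hidden_size)
intermediate = intermediate.permute([2, 0, 1, 3])  # [num_hidden_layers, batch_size, pre_seq_len, hidden_size]

print(prompts.shape, intermediate.shape)  # torch.Size([2, 4, 16]) torch.Size([3, 2, 4, 16])
```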
@ -1,268 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import hivemind
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.bloom import (
|
||||
BloomConfig,
|
||||
BloomForCausalLM,
|
||||
BloomForSequenceClassification,
|
||||
BloomModel,
|
||||
BloomPreTrainedModel,
|
||||
)
|
||||
|
||||
from petals.bloom.modeling_utils import LMHead
|
||||
from petals.client.remote_generation import RemoteGenerationMixin
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
||||
from petals.constants import PUBLIC_INITIAL_PEERS
|
||||
from petals.utils.misc import DUMMY
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
|
||||
"""
|
||||
A bloom config that contains information about DHT peers.
|
||||
To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
|
||||
"""
|
||||
|
||||
initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT
|
||||
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
|
||||
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
|
||||
|
||||
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
|
||||
tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
|
||||
|
||||
# This settings matter for running the client with dtype bfloat16 on CPU.
|
||||
# If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
|
||||
use_chunked_forward: Union[str, bool] = "auto"
|
||||
chunked_forward_step: int = 16384
|
||||
|
||||
|
||||
original_register_parameter = nn.Module.register_parameter
|
||||
|
||||
|
||||
@contextmanager
|
||||
def force_non_empty_weights():
|
||||
"""
|
||||
This context manager allows to bypass the accelerate.init_empty_weights() context manager
|
||||
(that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
|
||||
The transformers library should replace all meta tensors by empty tensors by itself
|
||||
but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
|
||||
|
||||
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
|
||||
"""
|
||||
|
||||
try:
|
||||
possibly_patched_register_parameter = nn.Module.register_parameter
|
||||
nn.Module.register_parameter = original_register_parameter
|
||||
yield
|
||||
finally:
|
||||
nn.Module.register_parameter = possibly_patched_register_parameter
|
||||
|
||||
|
||||
class _FromPretrainedDefaultsMixin:
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
*args,
|
||||
low_cpu_mem_usage: Optional[bool] = None,
|
||||
torch_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if low_cpu_mem_usage is None:
|
||||
low_cpu_mem_usage = True
|
||||
if torch_dtype is None:
|
||||
# torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
|
||||
# torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
|
||||
torch_dtype = "auto"
|
||||
return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
|
||||
|
||||
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
|
||||
"low_cpu_mem_usage(`bool`, *optional*)",
|
||||
"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
|
||||
).replace(
|
||||
"torch_dtype (`str` or `torch.dtype`, *optional*)",
|
||||
'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
|
||||
)
|
||||
|
||||
|
||||
class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
|
||||
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
||||
|
||||
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
|
||||
r"^(intermediate_)?prompt_embeddings\.weight$",
|
||||
]
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
|
||||
assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
|
||||
|
||||
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
|
||||
super().__init__(config)
|
||||
assert len(self.h) == 0
|
||||
config.n_layer = n_layer
|
||||
|
||||
self.h = RemoteSequential(config, dht=dht)
|
||||
|
||||
# Forbid accumulate grads for embeddings and layernorm
|
||||
self.set_requires_grad(False)
|
||||
|
||||
if config.tuning_mode and "ptune" in config.tuning_mode:
|
||||
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
||||
self.pre_seq_len = config.pre_seq_len
|
||||
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
||||
|
||||
with force_non_empty_weights():
|
||||
if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
|
||||
logger.info(
|
||||
"Prompt embeddings and their optimizer statistics will be kept in float32 "
|
||||
"to increase ptune quality"
|
||||
)
|
||||
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
|
||||
if config.tuning_mode == "deep_ptune":
|
||||
self.intermediate_prompt_embeddings = nn.Embedding(
|
||||
self.pre_seq_len,
|
||||
config.num_hidden_layers * config.hidden_size,
|
||||
# ^-- TODO: should be num_hidden_layers - 1
|
||||
dtype=torch.float32,
|
||||
)
|
||||
elif config.tuning_mode:
|
||||
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
||||
|
||||
def set_requires_grad(self, value):
|
||||
for p in self.parameters():
|
||||
p.requires_grad = value
|
||||
|
||||
def get_prompt(self, batch_size):
|
||||
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
|
||||
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
|
||||
prompts = self.prompt_embeddings(prefix_tokens)
|
||||
|
||||
if self.config.tuning_mode == "deep_ptune":
|
||||
intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
|
||||
intermediate_prompts = intermediate_prompts.view(
|
||||
batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size # TODO: should be len(self.h) - 1
|
||||
)
|
||||
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
|
||||
else:
|
||||
intermediate_prompts = DUMMY
|
||||
|
||||
dtype = self.word_embeddings.weight.dtype
|
||||
return prompts.to(dtype), intermediate_prompts.to(dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if not (v is None or v is False):
|
||||
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||
|
||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
|
||||
else:
|
||||
hidden_states = self.h(hidden_states)
|
||||
|
||||
# Remove prefix
|
||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=None,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
|
||||
class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
|
||||
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
||||
|
||||
_keys_to_ignore_on_load_missing = (
|
||||
BloomForCausalLM._keys_to_ignore_on_load_missing
|
||||
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||
+ [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings
|
||||
)
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig):
|
||||
BloomPreTrainedModel.__init__(self, config)
|
||||
self.transformer = DistributedBloomModel(config)
|
||||
self.lm_head = LMHead(config, self.transformer.word_embeddings)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.transformer.word_embeddings
|
||||
|
||||
def get_output_embeddings(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
return None
|
||||
return self.lm_head
|
||||
|
||||
def set_input_embeddings(self, new_embeddings: nn.Embedding):
|
||||
assert isinstance(new_embeddings, nn.Embedding)
|
||||
self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
|
||||
assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
|
||||
|
||||
def set_output_embeddings(self, new_lm_head: nn.Linear):
|
||||
with torch.no_grad():
|
||||
self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
|
||||
self.lm_head.bias[...] = new_lm_head.bias
|
||||
|
||||
|
||||
class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
|
||||
_keys_to_ignore_on_load_missing = (
|
||||
BloomForSequenceClassification._keys_to_ignore_on_load_missing
|
||||
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||
)
|
||||
|
||||
config_class = DistributedBloomConfig
|
||||
|
||||
def __init__(self, config: DistributedBloomConfig):
|
||||
BloomPreTrainedModel.__init__(self, config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = DistributedBloomModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
@ -0,0 +1,2 @@
from petals.models.bloom import *
from petals.models.llama import *
@ -0,0 +1,7 @@
from petals.models.bloom.block import WrappedBloomBlock
from petals.models.bloom.config import DistributedBloomConfig
from petals.models.bloom.model import (
    DistributedBloomForCausalLM,
    DistributedBloomForSequenceClassification,
    DistributedBloomModel,
)
@ -0,0 +1,32 @@
"""
Bloom intermediate layer
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
from typing import Optional, Tuple

import torch
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor


class WrappedBloomBlock(BloomBlock):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        assert attention_mask is None, "Non-causal attention masks are not supported yet"
        batch_size, seq_length = hidden_states.shape[:2]
        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
        seq_length_with_past = seq_length + past_length
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )
@ -0,0 +1,35 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention

from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.bloom.block import WrappedBloomBlock
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.version import get_compatible_model_repo

logger = get_logger(__name__)


class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedBloomBlock
    attn_class = BloomAttention
    block_prefix = "h"

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            # We need "-petals" for backward compatibility with Petals < 1.2.0
            dht_prefix = str(model_name_or_path) + "-petals"
            logger.info(f"Using DHT prefix: {dht_prefix}")
        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)


AutoDistributedConfig.register(DistributedBloomConfig)
@ -0,0 +1,134 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.models.bloom.config import DistributedBloomConfig

logger = get_logger(__name__)


class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = (
        BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing
    )
    _keys_to_ignore_on_load_unexpected = [r"^h\."]

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.num_hidden_layers = n_layer

        self.h = RemoteSequential(config, dht=dht)

        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.h(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
    _keys_to_ignore_on_load_missing = (
        BloomForCausalLM._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
        + [r"^lm_head\."]  # Missing since they are shared with input embeddings
    )
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head


class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
    _keys_to_ignore_on_load_missing = (
        BloomForSequenceClassification._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
    )
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)

        # Initialize weights and apply final processing
        self.post_init()
@ -0,0 +1,7 @@
from petals.models.llama.block import WrappedLlamaBlock
from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import (
    DistributedLlamaForCausalLM,
    DistributedLlamaForSequenceClassification,
    DistributedLlamaModel,
)
@ -0,0 +1,87 @@
"""
LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
from typing import Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel


class WrappedLlamaBlock(LlamaDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

        if position_ids is None:
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        attention_mask = LlamaModel._prepare_decoder_attention_mask(
            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_llama_to_bloom(
                present_key_value, batch_size, seq_length_with_past
            )
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_llama(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_from_llama_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        key_states, value_states = key_value
        value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
        key_states = key_states.view(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)
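Side note (not part of the diff): Petals servers keep the attention KV cache in the BLOOM layout, so the wrapped block converts it to the LLaMA layout on the way in and back on the way out. A self-contained sketch of that round trip with made-up dimensions (using `reshape` instead of `view` for clarity):

```python
import torch

batch, heads, head_dim, seq = 2, 4, 8, 5

# BLOOM-style cache layout (what the server stores)
key_bloom = torch.randn(batch * heads, head_dim, seq)
value_bloom = torch.randn(batch * heads, seq, head_dim)

# -> LLaMA layout, as in _reorder_cache_from_bloom_to_llama()
key_llama = key_bloom.permute(0, 2, 1).reshape(batch, heads, seq, head_dim)
value_llama = value_bloom.reshape(batch, heads, seq, head_dim)

# -> back to BLOOM layout, as in _reorder_cache_from_llama_to_bloom()
value_back = value_llama.reshape(batch * heads, seq, head_dim)
key_back = key_llama.reshape(batch * heads, seq, head_dim).permute(0, 2, 1)

assert torch.equal(key_back, key_bloom) and torch.equal(value_back, value_bloom)
```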
@ -0,0 +1,35 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.llama.block import WrappedLlamaBlock
from petals.utils.auto_config import AutoDistributedConfig

logger = get_logger(__name__)


class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedLlamaBlock
    attn_class = LlamaAttention
    block_prefix = "model.layers"

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            if "/" in dht_prefix:  # If present, strip repository name to merge blocks hosted by different accounts
                dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
            logger.info(f"Using DHT prefix: {dht_prefix}")
        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)


AutoDistributedConfig.register(DistributedLlamaConfig)
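Side note (not part of the diff): a short illustration of the DHT prefix rules in this config and the BLOOM one above, with example repo names. LLaMA drops the account name so blocks from differently-named mirrors land under one prefix, while BLOOM keeps the full repo name plus the legacy `-petals` suffix:

```python
def llama_dht_prefix(model_name_or_path: str) -> str:
    # mirrors DistributedLlamaConfig.from_pretrained() above
    dht_prefix = str(model_name_or_path)
    if "/" in dht_prefix:
        dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
    return dht_prefix

print(llama_dht_prefix("username/llama-65b-hf"))  # llama-65b-hf
print("bigscience/bloom" + "-petals")             # bigscience/bloom-petals (backward-compatible BLOOM prefix)
```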
@ -0,0 +1,152 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.models.llama.config import DistributedLlamaConfig

logger = get_logger(__name__)


class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
    """LlamaModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."]

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.layers) == 0
        config.num_hidden_layers = n_layer

        self.layers = RemoteSequential(config, dht=dht)

        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = inputs_embeds
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.layers(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers

    @property
    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return self.norm


class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig):
        LlamaPreTrainedModel.__init__(self, config)
        self.model = DistributedLlamaModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model


class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
    _keys_to_ignore_on_load_missing = (
        LlamaForSequenceClassification._keys_to_ignore_on_load_missing
        + DistributedLlamaModel._keys_to_ignore_on_load_missing
    )
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config):
        LlamaPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.model = DistributedLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model
@ -0,0 +1,175 @@
"""
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
If necessary, one can rewrite this to implement a different behavior, such as:
 - loading files from a local data source (e.g. S3)
 - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS ( https://docs.ipfs.io/how-to )
 - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
"""
import json
import time
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from huggingface_hub import get_hf_file_metadata, hf_hub_url
from transformers import PretrainedConfig
from transformers.utils import get_file_from_repo

from petals.server.block_utils import resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for

logger = get_logger(__name__)


def load_pretrained_block(
    model_name: str,
    block_index: int,
    *,
    config: Optional[PretrainedConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> nn.Module:
    if config is None:
        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
    torch_dtype = resolve_block_dtype(config, torch_dtype)

    with init_empty_weights():
        block = config.block_class(config)

    block_prefix = f"{config.block_prefix}.{block_index}."
    state_dict = _load_state_dict_from_repo(
        model_name,
        block_prefix,
        revision=revision,
        use_auth_token=use_auth_token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    # dummy load, check that keys match
    report = block.load_state_dict(state_dict, strict=True)
    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

    for param_name, _ in block.named_parameters():
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            param = param.to(torch_dtype)
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

    logger.info(f"Loaded {model_name} block {block_index}, {report}")
    return block


StateDict = Dict[str, torch.Tensor]


def _load_state_dict_from_repo(
    model_name: str,
    block_prefix: str,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
) -> StateDict:
    index_file = get_file_from_repo(
        model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir
    )
    if index_file is not None:  # Sharded model
        with open(index_file) as f:
            index = json.load(f)
        filenames = {
            filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
        }
        if not filenames:
            raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
    else:  # Non-sharded model
        filenames = {"pytorch_model.bin"}
    logger.debug(f"Loading {block_prefix}* from {filenames}")

    state_dict = {}
    for filename in filenames:
        shard_state_dict = _load_state_dict_from_file(
            model_name,
            filename,
            revision=revision,
            use_auth_token=use_auth_token,
            cache_dir=cache_dir,
            max_disk_space=max_disk_space,
        )
        shard_state_dict = {
            param_name[len(block_prefix) :]: param
            for param_name, param in shard_state_dict.items()
            if param_name.startswith(block_prefix)
        }  # Remove unused parameters from memory
        state_dict.update(shard_state_dict)
    return state_dict


def _load_state_dict_from_file(
    model_name: str,
    filename: str,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
) -> StateDict:
    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            path = get_file_from_repo(
                model_name,
                filename,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if path is not None:
                return torch.load(path, map_location="cpu")
    except Exception:
        logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    while True:
        try:
            with allow_cache_writes(cache_dir):
                url = hf_hub_url(model_name, filename, revision=revision)
                file_size = get_hf_file_metadata(url, token=use_auth_token).size
                if file_size is not None:
                    free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")

                path = get_file_from_repo(
                    model_name,
                    filename,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                if path is None:
                    raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
                return torch.load(path, map_location="cpu")
        except Exception as e:
            logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)


DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
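A short sketch of how a server-side caller might use this loader. The checkpoint name is a placeholder; when `config` is omitted, it is resolved via `AutoDistributedConfig`:

```python
from petals.server.from_pretrained import load_pretrained_block

# Placeholder checkpoint; any supported BLOOM or LLaMA repo in HF format should work
block = load_pretrained_block(
    "bigscience/bloom-560m",
    block_index=3,
    torch_dtype="auto",  # resolved to a concrete dtype by resolve_block_dtype()
)
print(sum(p.numel() for p in block.parameters()), "parameters loaded on CPU")
```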
@@ -0,0 +1 @@
from petals.utils.auto_config import AutoDistributedConfig
@@ -0,0 +1,23 @@
from typing import Type

from transformers import AutoConfig, PretrainedConfig

CONFIG_MAPPING = {}  # Populated with AutoDistributedConfig.register()


class AutoDistributedConfig:
    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
        config = AutoConfig.from_pretrained(*args, **kwargs)
        if config.model_type not in CONFIG_MAPPING:
            raise ValueError(f"Petals does not support model type {config.model_type}")

        dist_config_class = CONFIG_MAPPING[config.model_type]
        return dist_config_class.from_pretrained(*args, **kwargs)

    @staticmethod
    def register(config_class: Type[PretrainedConfig]) -> None:
        assert issubclass(config_class, PretrainedConfig)
        assert config_class.model_type not in CONFIG_MAPPING

        CONFIG_MAPPING[config_class.model_type] = config_class
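A minimal sketch of how a model package is expected to register its config class. `DistributedToyConfig` is a hypothetical stand-in; in this PR, the real `Distributed*Config` classes are registered by the `petals.models` packages at import time:

```python
from transformers import PretrainedConfig

from petals.utils.auto_config import CONFIG_MAPPING, AutoDistributedConfig


class DistributedToyConfig(PretrainedConfig):  # hypothetical config, for illustration only
    model_type = "toy"


AutoDistributedConfig.register(DistributedToyConfig)
assert CONFIG_MAPPING["toy"] is DistributedToyConfig

# From now on, AutoDistributedConfig.from_pretrained() dispatches any repo whose
# AutoConfig resolves to model_type == "toy" to DistributedToyConfig.from_pretrained()
```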
@@ -1,17 +1,16 @@
 import pytest
 import torch
 
-from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import DistributedBloomConfig
 from petals.server.block_utils import resolve_block_dtype
+from petals.server.from_pretrained import load_pretrained_block
+from petals.utils.auto_config import AutoDistributedConfig
 from test_utils import MODEL_NAME
 
 
 @pytest.mark.forked
 @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
-def test_backend_dtype(torch_dtype):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype=torch_dtype)
-    backend_dtype = resolve_block_dtype(config, torch_dtype)
-    other_backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
-    assert backend_dtype == other_backend_dtype
+def test_block_dtype(torch_dtype):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+    block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)
+    expected_dtype = resolve_block_dtype(config, torch_dtype)
+    assert all(param.dtype == expected_dtype for param in block.parameters())