From cb3f018f9f0362ff4d2aa77c6950c1b6aabcdc43 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Fri, 23 Jun 2023 15:46:10 +0400
Subject: [PATCH] Add LLaMA support (#323)

This PR:

1. **Abolishes the model conversion procedure.** Models are now downloaded directly from their original repositories, such as https://huggingface.co/bigscience/bloom. Servers download only the shards with the blocks they host, and clients download only the shards with the input/output embeddings and layernorms.
   - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. The same applies to the smaller BLOOM models and BLOOMZ.
   - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accommodate blocks from different repos (there are a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Model-specific code now lives in `petals.models` subpackages (e.g. `petals.models.bloom`, `petals.models.llama`), while general code (e.g. the CPU-efficient LM head, p-tuning) stays in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel`, all compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig`, which automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info needed by both clients and servers.

Upgrade instructions:

- Remove disk caches for blocks in the old (converted) format to save disk space. That is, remove the `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
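For reference, here is a minimal client-side sketch of the API after this change (not part of the patch itself). It assumes that `AutoDistributedConfig` and `DistributedLlamaForCausalLM` are re-exported from the top-level `petals` package (as the updated `petals/__init__.py` suggests), that `username/llama-65b-hf` is a placeholder repo name, and that a swarm hosting this model is reachable via the default initial peers:

```python
from transformers import AutoTokenizer

from petals import AutoDistributedConfig, DistributedLlamaForCausalLM

MODEL_NAME = "username/llama-65b-hf"  # hypothetical repo; the DHT prefix drops the username

# The correct config class (DistributedLlamaConfig here) is chosen automatically.
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
print(config.dht_prefix)  # expected: "llama-65b-hf", per the description above

# LLaMA tokenizers need sentencepiece, which this PR adds as a dependency.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# The client downloads only the embedding/layernorm shards; transformer blocks run in the swarm.
model = DistributedLlamaForCausalLM.from_pretrained(MODEL_NAME)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))
```

Passing an explicit `dht_prefix=...` to `from_pretrained()` should still work for repos whose name differs from the prefix used by the swarm.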
--- .github/workflows/run-tests.yaml | 66 +---- setup.cfg | 4 +- src/petals/__init__.py | 12 +- src/petals/bloom/__init__.py | 0 src/petals/bloom/block.py | 62 ---- src/petals/bloom/from_pretrained.py | 132 --------- src/petals/cli/config.json | 20 -- src/petals/cli/convert_model.py | 96 ------- src/petals/cli/inference_one_block.py | 2 +- src/petals/cli/run_server.py | 2 +- src/petals/client/__init__.py | 6 - src/petals/client/from_pretrained.py | 94 ++++++ .../modeling_utils.py => client/lm_head.py} | 72 ++--- src/petals/client/ptune.py | 88 ++++++ src/petals/client/remote_model.py | 268 ------------------ src/petals/client/remote_sequential.py | 7 +- src/petals/client/routing/sequence_manager.py | 9 +- src/petals/models/__init__.py | 2 + src/petals/models/bloom/__init__.py | 7 + src/petals/models/bloom/block.py | 32 +++ src/petals/models/bloom/config.py | 35 +++ src/petals/models/bloom/model.py | 134 +++++++++ src/petals/models/llama/__init__.py | 7 + src/petals/models/llama/block.py | 87 ++++++ src/petals/models/llama/config.py | 35 +++ src/petals/models/llama/model.py | 152 ++++++++++ src/petals/server/backend.py | 21 +- src/petals/server/block_utils.py | 10 +- src/petals/server/from_pretrained.py | 175 ++++++++++++ src/petals/server/server.py | 64 +++-- src/petals/server/throughput.py | 22 +- src/petals/utils/__init__.py | 1 + src/petals/utils/auto_config.py | 23 ++ src/petals/utils/convert_block.py | 28 +- src/petals/utils/disk_cache.py | 8 +- src/petals/utils/version.py | 20 +- tests/test_aux_functions.py | 4 +- tests/test_block_exact_match.py | 70 +---- tests/test_chained_calls.py | 4 +- tests/test_dtype.py | 15 +- tests/test_full_model.py | 4 +- tests/test_remote_sequential.py | 18 +- tests/test_sequence_manager.py | 4 +- tests/test_server_stats.py | 2 +- tests/test_tensor_parallel.py | 2 +- 45 files changed, 1073 insertions(+), 853 deletions(-) delete mode 100644 src/petals/bloom/__init__.py delete mode 100644 src/petals/bloom/block.py delete mode 100644 src/petals/bloom/from_pretrained.py delete mode 100644 src/petals/cli/config.json delete mode 100644 src/petals/cli/convert_model.py create mode 100644 src/petals/client/from_pretrained.py rename src/petals/{bloom/modeling_utils.py => client/lm_head.py} (53%) create mode 100644 src/petals/client/ptune.py delete mode 100644 src/petals/client/remote_model.py create mode 100644 src/petals/models/__init__.py create mode 100644 src/petals/models/bloom/__init__.py create mode 100644 src/petals/models/bloom/block.py create mode 100644 src/petals/models/bloom/config.py create mode 100644 src/petals/models/bloom/model.py create mode 100644 src/petals/models/llama/__init__.py create mode 100644 src/petals/models/llama/block.py create mode 100644 src/petals/models/llama/config.py create mode 100644 src/petals/models/llama/model.py create mode 100644 src/petals/server/from_pretrained.py create mode 100644 src/petals/utils/auto_config.py diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 37edb8f..fbb5b72 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -6,57 +6,8 @@ on: pull_request: jobs: - convert-model: - runs-on: ubuntu-latest - env: - BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }} - timeout-minutes: 15 - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Check if the model is cached - id: cache-model - uses: actions/cache@v3 - with: - path: ~/converted_ok - key: model-v1-${{ hashFiles('setup.cfg', 
'src/petals/cli/convert_model.py') }} - - name: Set up Python - if: steps.cache-model.outputs.cache-hit != 'true' - uses: actions/setup-python@v3 - with: - python-version: 3.9 - - name: Cache dependencies - if: steps.cache-model.outputs.cache-hit != 'true' - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: Key-v1-3.9-${{ hashFiles('setup.cfg') }} - - name: Install dependencies - if: steps.cache-model.outputs.cache-hit != 'true' - run: | - python -m pip install --upgrade pip - pip install . - - name: Delete any test models older than 1 week - if: steps.cache-model.outputs.cache-hit != 'true' - run: | - python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN - - name: Delete previous version of this model, if exists - if: steps.cache-model.outputs.cache-hit != 'true' - run: | - export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") - python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \ - repo_id='bloom-testing/test-bloomd-560m-$HF_TAG')" || true - - name: Convert model and push to hub - if: steps.cache-model.outputs.cache-hit != 'true' - run: | - export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \ - --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \ - --resize_token_embeddings 50000 && touch ~/converted_ok - run-tests: runs-on: ubuntu-latest - needs: convert-model strategy: matrix: python-version: [ '3.7', '3.8', '3.9', '3.10' ] @@ -80,8 +31,7 @@ jobs: pip install .[dev] - name: Test run: | - export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG + export MODEL_NAME=bigscience/bloom-560m export REF_NAME=bigscience/bloom-560m python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ @@ -104,23 +54,19 @@ jobs: --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log & SERVER3_PID=$! - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \ - --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log & - SERVER4_PID=$! - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \ - --initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu --torch_dtype float32 &> server5.log & - SERVER5_PID=$! + --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log & + SERVER4_PID=$! tail -n 100 -f server*.log & LOGGER_PID=$! sleep 30 # wait for servers to download layers - kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init + kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init pytest tests --durations=0 --durations-min=1.0 -v - kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests + kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests - kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID $LOGGER_PID + kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID echo "Done!" 
diff --git a/setup.cfg b/setup.cfg index 8c237aa..4722c63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,8 @@ install_requires = bitsandbytes==0.38.0.post2 accelerate>=0.16.0,<1.0.0 huggingface-hub>=0.11.1,<1.0.0 - transformers>=4.25.1,<5.0.0 + tokenizers>=0.13.3 + transformers>=4.30.1,<5.0.0 speedtest-cli==2.1.3 hivemind==1.1.8 tensor_parallel==1.0.23 @@ -43,6 +44,7 @@ install_requires = async-timeout>=4.0.2 cpufeature>=0.2.0 packaging>=20.9 + sentencepiece>=0.1.99 [options.extras_require] dev = diff --git a/src/petals/__init__.py b/src/petals/__init__.py index b50b251..26aa3ab 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,11 +1,21 @@ import os import hivemind +import transformers +from packaging import version from petals.client import * +from petals.models import * +from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.1.5" +__version__ = "1.2.0.dev0" + + +if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): + assert ( + version.parse("4.30.1") <= version.parse(transformers.__version__) < version.parse("5.0.0") + ), "Please install a proper transformers version: pip install transformers>=4.30.1,<5.0.0" def _override_bfloat16_mode_default(): diff --git a/src/petals/bloom/__init__.py b/src/petals/bloom/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/petals/bloom/block.py b/src/petals/bloom/block.py deleted file mode 100644 index 9037ee4..0000000 --- a/src/petals/bloom/block.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Bloom intermediate layer -Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b -See commit history for authorship. -""" -import os -from typing import Optional, Tuple - -import torch.nn.quantized.dynamic.modules.linear -import transformers -from packaging import version -from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor - -if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): - assert ( - version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0") - ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0" - - -class WrappedBloomBlock(BloomBlock): - def forward( - self, - hidden_states: torch.Tensor, - *args, - attention_mask: Optional[torch.Tensor] = None, - alibi: Optional[torch.Tensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs - ): - assert attention_mask is None - batch_size, seq_length = hidden_states.shape[:2] - past_length = 0 if layer_past is None else layer_past[0].shape[-1] - seq_length_with_past = seq_length + past_length - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - if alibi is None: - alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) - attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) - return super().forward( - hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs - ) - - def _prepare_attn_mask( - self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int - ) -> torch.BoolTensor: - # create causal mask - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - if src_length > 1: - 
combined_attention_mask = _make_causal_mask( - torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py deleted file mode 100644 index d40b01f..0000000 --- a/src/petals/bloom/from_pretrained.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code. -If necessary, one can rewrite this to implement a different behavior, such as: - - loading files from a local data source (e.g. S3) - - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to ) - - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html ) - -""" -from __future__ import annotations - -import itertools -import time -from typing import Optional, OrderedDict, Union - -import torch -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from hivemind.utils.logging import get_logger -from transformers.modeling_utils import WEIGHTS_NAME -from transformers.models.bloom.configuration_bloom import BloomConfig -from transformers.utils import get_file_from_repo - -from petals.bloom.block import WrappedBloomBlock -from petals.server.block_utils import get_block_size, resolve_block_dtype -from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for - -logger = get_logger(__name__) - -CLIENT_BRANCH = "main" -BLOCK_BRANCH_PREFIX = "block_" - - -def load_pretrained_block( - converted_model_name_or_path: str, - block_index: int, - config: Optional[BloomConfig] = None, - torch_dtype: Union[torch.dtype, str] = "auto", - use_auth_token: Optional[str] = None, - cache_dir: Optional[str] = None, - max_disk_space: Optional[int] = None, -) -> WrappedBloomBlock: - """Load one BLOOM block from a converted model. 
See convert_model.py (or README.md) on how to convert it.""" - assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - torch_dtype = resolve_block_dtype(config, torch_dtype) - - if config is None: - config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) - if cache_dir is None: - cache_dir = DEFAULT_CACHE_DIR - - with init_empty_weights(): - block = WrappedBloomBlock(config) - - state_dict = _load_state_dict( - converted_model_name_or_path, - block_index, - config, - use_auth_token=use_auth_token, - cache_dir=cache_dir, - max_disk_space=max_disk_space, - ) - - # dummy load, check that keys match - report = block.load_state_dict(state_dict, strict=True) - assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" - - for param_name, _ in block.named_parameters(): - assert param_name in state_dict, f"{param_name} not in state dict" - param = state_dict[param_name] - if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): - param = param.to(torch_dtype) - set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) - - logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}") - return block - - -def _load_state_dict( - pretrained_model_name_or_path: str, - block_index: int, - config: BloomConfig, - *, - use_auth_token: Optional[str] = None, - cache_dir: str, - max_disk_space: Optional[int] = None, - min_backoff: float = 5, -) -> OrderedDict[str, torch.Tensor]: - revision = BLOCK_BRANCH_PREFIX + str(block_index) - - # First, try to find the weights locally - try: - with allow_cache_reads(cache_dir): - archive_file = get_file_from_repo( - pretrained_model_name_or_path, - filename=WEIGHTS_NAME, - revision=revision, - use_auth_token=use_auth_token, - cache_dir=cache_dir, - local_files_only=True, - ) - if archive_file is not None: - return torch.load(archive_file, map_location="cpu") - except Exception: - logger.debug( - f"Failed to load block {block_index} from cache. 
The block will be downloaded again", exc_info=True - ) - - # If not found, ensure that we have enough disk space to download them (maybe remove something) - for attempt_no in itertools.count(): - try: - with allow_cache_writes(cache_dir): - block_size = get_block_size(config, "disk") - free_disk_space_for( - pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space - ) - - archive_file = get_file_from_repo( - pretrained_model_name_or_path, - filename=WEIGHTS_NAME, - revision=revision, - use_auth_token=use_auth_token, - cache_dir=cache_dir, - local_files_only=False, - ) - return torch.load(archive_file, map_location="cpu") - except Exception as e: - delay = min_backoff * (2**attempt_no) - logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True) - time.sleep(delay) - - -DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/cli/config.json b/src/petals/cli/config.json deleted file mode 100644 index ca7ffbb..0000000 --- a/src/petals/cli/config.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "apply_residual_connection_post_layernorm": false, - "attention_dropout": 0.0, - "attention_softmax_in_fp32": true, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_dropout": 0.0, - "initializer_range": 0.02, - "layer_norm_epsilon": 1e-05, - "masked_softmax_fusion": true, - "model_type": "bloom", - "n_embed": 14336, - "n_layer": 70, - "num_attention_heads": 112, - "pretraining_tp": 4, - "slow_but_exact": false, - "transformers_version": "4.20.0.dev0", - "use_cache": true, - "vocab_size": 250880 -} \ No newline at end of file diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py deleted file mode 100644 index 15e12b7..0000000 --- a/src/petals/cli/convert_model.py +++ /dev/null @@ -1,96 +0,0 @@ -import argparse -import os - -import psutil -import torch.backends.quantized -import torch.nn as nn -import transformers -from hivemind.utils.logging import get_logger -from huggingface_hub import HfApi, Repository -from tqdm.auto import tqdm -from transformers.models.bloom.modeling_bloom import BloomModel - -from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP -from petals.client import DistributedBloomConfig - -logger = get_logger(__name__) - - -def main(): - parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.") - - parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained") - parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub") - parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype") - parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder") - parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo") - parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch") - parser.add_argument( - "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix" - ) - parser.add_argument( - "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts" - ) - parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained") - 
parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size") - args = parser.parse_args() - - free_ram_gb = psutil.virtual_memory().available / 2**30 - if args.model == "bigscience/bloom" and free_ram_gb < 400: - logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free") - - assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}" - if os.path.exists(args.output_path) and ( - len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path) - ): - raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory") - - logger.info(f"Loading source model {args.model} (this may take a few minutes)") - config = DistributedBloomConfig.from_pretrained( - args.model, use_auth_token=args.use_auth_token, revision=args.revision - ) - config.dht_prefix = args.output_repo - - model = BloomModel.from_pretrained( - args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype] - ) - if args.resize_token_embeddings: - logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}") - model.resize_token_embeddings(args.resize_token_embeddings) - config.vocab_size = args.resize_token_embeddings - - tokenizer = transformers.AutoTokenizer.from_pretrained( - args.model, use_auth_token=args.use_auth_token, revision=args.revision - ) - os.makedirs(args.output_path, exist_ok=True) - - api = HfApi(token=args.use_auth_token) - api.create_repo(args.output_repo, repo_type="model", exist_ok=True) - repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token) - repo.git_pull() - - transformer_blocks = model.h - logger.info( - f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0" - f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}" - ) - for i, block in enumerate(tqdm(transformer_blocks)): - repo.git_checkout(args.client_branch, create_branch_ok=True) - with repo.commit( - commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True - ): - torch.save(block.state_dict(), "./pytorch_model.bin") - - logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}") - repo.git_checkout(args.client_branch, create_branch_ok=True) - with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True): - model.h = nn.ModuleList() - model.save_pretrained(".") - tokenizer.save_pretrained(".") - config.save_pretrained(".") - - logger.info(f"Converted {args.model} and pushed to {args.output_repo}") - - -if __name__ == "__main__": - main() diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py index 01ba1ef..6d53e9b 100644 --- a/src/petals/cli/inference_one_block.py +++ b/src/petals/cli/inference_one_block.py @@ -6,7 +6,7 @@ from tqdm.auto import trange from transformers import BloomConfig from transformers.models.bloom.modeling_bloom import build_alibi_tensor -from petals.bloom.block import BloomBlock +from petals.models.bloom.block import BloomBlock logger = get_logger(__name__) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index fb521ef..4c6f0e5 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -87,7 +87,7 @@ def main(): parser.add_argument('--alloc_timeout', type=float, default=60, help='If the cache is full, 
the server will wait for this number of seconds hoping that some memory will be freed ' 'before rejecting the request') - parser.add_argument('--revision', type=str, default='main', + parser.add_argument('--revision', type=str, default=None, help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models" "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") diff --git a/src/petals/client/__init__.py b/src/petals/client/__init__.py index 5ff26bc..f80c4b1 100644 --- a/src/petals/client/__init__.py +++ b/src/petals/client/__init__.py @@ -1,10 +1,4 @@ from petals.client.inference_session import InferenceSession -from petals.client.remote_model import ( - DistributedBloomConfig, - DistributedBloomForCausalLM, - DistributedBloomForSequenceClassification, - DistributedBloomModel, -) from petals.client.remote_sequential import RemoteSequential from petals.client.routing.sequence_manager import RemoteSequenceManager from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase diff --git a/src/petals/client/from_pretrained.py b/src/petals/client/from_pretrained.py new file mode 100644 index 0000000..b8d02c0 --- /dev/null +++ b/src/petals/client/from_pretrained.py @@ -0,0 +1,94 @@ +import contextlib +import json +import os +import re +import tempfile +import threading +from typing import List, Optional, Tuple, Union + +import torch +from hivemind.utils.logging import get_logger +from transformers import BloomPreTrainedModel, modeling_utils + +from petals.utils.version import get_compatible_model_repo + +logger = get_logger(__name__) + + +class FromPretrainedMixin: + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, os.PathLike, None], + *args, + low_cpu_mem_usage: Optional[bool] = None, + torch_dtype: Optional[Union[str, torch.dtype]] = None, + **kwargs, + ): + model_name_or_path = get_compatible_model_repo(model_name_or_path) + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + if torch_dtype is None: + # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast, + # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights. 
+ torch_dtype = "auto" + + with ignore_keys(cls._keys_to_ignore_on_load_unexpected): + return super().from_pretrained( + model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs + ) + + from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( + "low_cpu_mem_usage(`bool`, *optional*)", + "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", + ).replace( + "torch_dtype (`str` or `torch.dtype`, *optional*)", + 'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)', + ) + + +_shard_config = threading.local() +_shard_config.ignored_keys = None + + +@contextlib.contextmanager +def ignore_keys(patterns: List[str]): + try: + prev_patterns = _shard_config.ignored_keys + _shard_config.ignored_keys = patterns + yield + finally: + _shard_config.ignored_keys = prev_patterns + + +def patched_get_checkpoint_shard_files( + pretrained_model_name_or_path, index_filename, *args, **kwargs +) -> Tuple[List[str], dict]: + """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys.""" + + should_ignore_keys = _shard_config.ignored_keys is not None + tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext() + with tempdir_ctx as tempdir: + if should_ignore_keys: + with open(index_filename) as f: + index = json.load(f) + n_original_shards = len(set(index["weight_map"].values())) + + index["weight_map"] = { + param_name: filename + for param_name, filename in index["weight_map"].items() + if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys) + } + n_loaded_shards = len(set(index["weight_map"].values())) + logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}") + + # Replace the original index with a patched JSON, where ignored keys are removed + index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json") + with open(index_filename, "w") as f: + json.dump(index, f) + + return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs) + + +original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files +modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/client/lm_head.py similarity index 53% rename from src/petals/bloom/modeling_utils.py rename to src/petals/client/lm_head.py index eddbb9d..ddd2887 100644 --- a/src/petals/bloom/modeling_utils.py +++ b/src/petals/client/lm_head.py @@ -1,10 +1,6 @@ -""" -PyTorch BLOOM model that implements several memory-efficient modes. -Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b -See commit history for authorship. -""" - +import dataclasses import platform +from typing import Optional, Union import psutil import torch @@ -12,21 +8,30 @@ import torch.nn.functional as F import torch.utils.checkpoint from hivemind import get_logger from torch import nn -from transformers import BloomConfig +from transformers import PretrainedConfig logger = get_logger(__name__) -class LMHead(nn.Module): - """ - The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input - embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. - In addition, it provides an effcient way to deal with half-precision word embeddings on CPU. 
- """ +@dataclasses.dataclass +class LMHeadConfig: + # This settings matter for running the client with dtype bfloat16 on CPU. + # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations. + use_chunked_forward: Union[str, bool] = "auto" + chunked_forward_step: int = 16384 + - def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding): +class LMHead(nn.Module): + def __init__(self, config: PretrainedConfig): super().__init__() - self.word_embeddings = word_embeddings + + if not config.tie_word_embeddings: + self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False)) + else: + self.weight = None # Will be set to get_input_embeddings().weight during loading the model + self.bias = None + self.in_features = config.hidden_size # Similar to nn.Linear attributes + self.out_features = config.vocab_size self.use_chunked_forward = config.use_chunked_forward if self.use_chunked_forward == "auto": @@ -42,35 +47,17 @@ class LMHead(nn.Module): self.chunked_forward_step = config.chunked_forward_step self._bf16_warning_shown = False - @property - def in_features(self) -> int: - return self.word_embeddings.num_embeddings - - @property - def out_features(self) -> int: - return self.word_embeddings.embedding_dim - - @property - def weight(self): - return self.word_embeddings.weight - - @property - def bias(self): - return None - def forward(self, hidden_states): - word_embeddings = self.word_embeddings.weight - if ( - word_embeddings.dtype in [torch.float16, torch.bfloat16] - and word_embeddings.device.type == "cpu" + self.weight.dtype in [torch.float16, torch.bfloat16] + and self.weight.device.type == "cpu" and self.use_chunked_forward ): lm_logits = self.chunked_forward(hidden_states) else: # Switch dtype in case word_embeddings are fp16/bf16 - hidden_states = hidden_states.to(word_embeddings.dtype) - lm_logits = F.linear(hidden_states, word_embeddings) + hidden_states = hidden_states.to(self.weight.dtype) + lm_logits = F.linear(hidden_states, self.weight) return lm_logits def chunked_forward(self, hidden_states): @@ -80,20 +67,17 @@ class LMHead(nn.Module): assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive" if not self._bf16_warning_shown: - if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: + if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: logger.warning( "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. 
" "Consider loading the model with torch_dtype='float32'" ) self._bf16_warning_shown = True - word_embeddings = self.word_embeddings.weight - num_embeddings = self.word_embeddings.num_embeddings - hidden_states = hidden_states.float() - output = torch.empty(*hidden_states.shape[:-1], num_embeddings) + output = torch.empty(*hidden_states.shape[:-1], self.out_features) - for i in range(0, num_embeddings, self.chunked_forward_step): - chunk = word_embeddings[i : i + self.chunked_forward_step].float() + for i in range(0, self.out_features, self.chunked_forward_step): + chunk = self.weight[i : i + self.chunked_forward_step].float() output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk) return output diff --git a/src/petals/client/ptune.py b/src/petals/client/ptune.py new file mode 100644 index 0000000..5cf613c --- /dev/null +++ b/src/petals/client/ptune.py @@ -0,0 +1,88 @@ +import dataclasses +from contextlib import contextmanager +from typing import Optional + +import torch +import torch.nn as nn +from hivemind import get_logger +from transformers import PretrainedConfig + +from petals.utils.misc import DUMMY + +logger = get_logger(__name__) + + +@dataclasses.dataclass +class PTuneConfig: + pre_seq_len: int = 0 # a number of tokens for prompt tuning. + tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"] + + +class PTuneMixin: + _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"] + + def init_prompts(self, config: PretrainedConfig) -> None: + if config.tuning_mode and "ptune" in config.tuning_mode: + assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0" + self.pre_seq_len = config.pre_seq_len + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + + with force_non_empty_weights(): + # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality + self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32) + if config.tuning_mode == "deep_ptune": + self.intermediate_prompt_embeddings = nn.Embedding( + self.pre_seq_len, + config.num_hidden_layers * config.hidden_size, + # ^-- TODO: should be num_hidden_layers - 1 + dtype=torch.float32, + ) + elif config.tuning_mode: + raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now") + + def set_requires_grad(self, value): + for p in self.parameters(): + p.requires_grad = value + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1) + prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device) + prompts = self.prompt_embeddings(prefix_tokens) + + if self.config.tuning_mode == "deep_ptune": + intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens) + intermediate_prompts = intermediate_prompts.view( + batch_size, + self.pre_seq_len, + self.config.num_hidden_layers, + self.config.hidden_size + # TODO: should be num_hidden_layers - 1 + ) + intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3]) + else: + intermediate_prompts = DUMMY + + dtype = self.word_embeddings.weight.dtype + return prompts.to(dtype), intermediate_prompts.to(dtype) + + +_original_register_parameter = nn.Module.register_parameter + + +@contextmanager +def force_non_empty_weights(): + """ + This context manager allows to bypass the accelerate.init_empty_weights() context manager + (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True. 
+ The transformers library should replace all meta tensors by empty tensors by itself + but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`). + + [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515 + """ + + try: + possibly_patched_register_parameter = nn.Module.register_parameter + nn.Module.register_parameter = _original_register_parameter + yield + finally: + nn.Module.register_parameter = possibly_patched_register_parameter diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py deleted file mode 100644 index b556714..0000000 --- a/src/petals/client/remote_model.py +++ /dev/null @@ -1,268 +0,0 @@ -from contextlib import contextmanager -from typing import List, Optional, Union - -import hivemind -import torch -import torch.nn as nn -from hivemind.utils.logging import get_logger -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.bloom import ( - BloomConfig, - BloomForCausalLM, - BloomForSequenceClassification, - BloomModel, - BloomPreTrainedModel, -) - -from petals.bloom.modeling_utils import LMHead -from petals.client.remote_generation import RemoteGenerationMixin -from petals.client.remote_sequential import RemoteSequential -from petals.client.routing.sequence_manager import SequenceManagerConfig -from petals.constants import PUBLIC_INITIAL_PEERS -from petals.utils.misc import DUMMY - -logger = get_logger(__name__) - - -class DistributedBloomConfig(BloomConfig, SequenceManagerConfig): - """ - A bloom config that contains information about DHT peers. - To create a distributed model, one must provide dht_prefix and either initial_peers or dht. - """ - - initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT - dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) - daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers - - pre_seq_len: int = 0 # a number of tokens for prompt tuning. - tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"] - - # This settings matter for running the client with dtype bfloat16 on CPU. - # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations. - use_chunked_forward: Union[str, bool] = "auto" - chunked_forward_step: int = 16384 - - -original_register_parameter = nn.Module.register_parameter - - -@contextmanager -def force_non_empty_weights(): - """ - This context manager allows to bypass the accelerate.init_empty_weights() context manager - (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True. - The transformers library should replace all meta tensors by empty tensors by itself - but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`). 
- - [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515 - """ - - try: - possibly_patched_register_parameter = nn.Module.register_parameter - nn.Module.register_parameter = original_register_parameter - yield - finally: - nn.Module.register_parameter = possibly_patched_register_parameter - - -class _FromPretrainedDefaultsMixin: - @classmethod - def from_pretrained( - cls, - *args, - low_cpu_mem_usage: Optional[bool] = None, - torch_dtype: Optional[Union[str, torch.dtype]] = None, - **kwargs, - ): - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - if torch_dtype is None: - # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast, - # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights. - torch_dtype = "auto" - return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs) - - from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( - "low_cpu_mem_usage(`bool`, *optional*)", - "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", - ).replace( - "torch_dtype (`str` or `torch.dtype`, *optional*)", - 'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)', - ) - - -class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel): - """BloomModel, but all transformer layers are hosted by the swarm""" - - _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [ - r"^(intermediate_)?prompt_embeddings\.weight$", - ] - - config_class = DistributedBloomConfig - - def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None): - assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." 
- assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" - - n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization - super().__init__(config) - assert len(self.h) == 0 - config.n_layer = n_layer - - self.h = RemoteSequential(config, dht=dht) - - # Forbid accumulate grads for embeddings and layernorm - self.set_requires_grad(False) - - if config.tuning_mode and "ptune" in config.tuning_mode: - assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0" - self.pre_seq_len = config.pre_seq_len - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - - with force_non_empty_weights(): - if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16): - logger.info( - "Prompt embeddings and their optimizer statistics will be kept in float32 " - "to increase ptune quality" - ) - self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32) - if config.tuning_mode == "deep_ptune": - self.intermediate_prompt_embeddings = nn.Embedding( - self.pre_seq_len, - config.num_hidden_layers * config.hidden_size, - # ^-- TODO: should be num_hidden_layers - 1 - dtype=torch.float32, - ) - elif config.tuning_mode: - raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now") - - def set_requires_grad(self, value): - for p in self.parameters(): - p.requires_grad = value - - def get_prompt(self, batch_size): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1) - prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device) - prompts = self.prompt_embeddings(prefix_tokens) - - if self.config.tuning_mode == "deep_ptune": - intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens) - intermediate_prompts = intermediate_prompts.view( - batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size # TODO: should be len(self.h) - 1 - ) - intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3]) - else: - intermediate_prompts = DUMMY - - dtype = self.word_embeddings.weight.dtype - return prompts.to(dtype), intermediate_prompts.to(dtype) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ): - assert attention_mask is None, "DistributedBloomModel does not support attention masks right now" - - for k, v in kwargs.items(): - if not (v is None or v is False): - logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})") - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: - batch_size = inputs_embeds.shape[0] - prompts, intermediate_prompts = self.get_prompt(batch_size) - inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - output_shape = input_shape + (hidden_states.size(-1),) - - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: - 
hidden_states = self.h(hidden_states, prompts=intermediate_prompts) - else: - hidden_states = self.h(hidden_states) - - # Remove prefix - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: - hidden_states = hidden_states[:, self.pre_seq_len :] - - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM): - """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm""" - - _keys_to_ignore_on_load_missing = ( - BloomForCausalLM._keys_to_ignore_on_load_missing - + DistributedBloomModel._keys_to_ignore_on_load_missing - + [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings - ) - - config_class = DistributedBloomConfig - - def __init__(self, config: DistributedBloomConfig): - BloomPreTrainedModel.__init__(self, config) - self.transformer = DistributedBloomModel(config) - self.lm_head = LMHead(config, self.transformer.word_embeddings) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.transformer.word_embeddings - - def get_output_embeddings(self): - if self.config.tie_word_embeddings: - return None - return self.lm_head - - def set_input_embeddings(self, new_embeddings: nn.Embedding): - assert isinstance(new_embeddings, nn.Embedding) - self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings - assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings - - def set_output_embeddings(self, new_lm_head: nn.Linear): - with torch.no_grad(): - self.lm_head.word_embeddings.weight[...] = new_lm_head.weight - self.lm_head.bias[...] 
= new_lm_head.bias - - -class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification): - _keys_to_ignore_on_load_missing = ( - BloomForSequenceClassification._keys_to_ignore_on_load_missing - + DistributedBloomModel._keys_to_ignore_on_load_missing - ) - - config_class = DistributedBloomConfig - - def __init__(self, config: DistributedBloomConfig): - BloomPreTrainedModel.__init__(self, config) - self.num_labels = config.num_labels - - self.transformer = DistributedBloomModel(config) - self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype) - - # Initialize weights and apply final processing - self.post_init() diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 39811e3..745b5c1 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -6,9 +6,8 @@ import torch from hivemind import DHT, get_logger from torch import nn -import petals.client from petals.client.inference_session import InferenceSession -from petals.client.routing.sequence_manager import RemoteSequenceManager +from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction from petals.data_structures import UID_DELIMITER from petals.utils.misc import DUMMY @@ -23,7 +22,7 @@ class RemoteSequential(nn.Module): def __init__( self, - config: petals.client.DistributedBloomConfig, + config: SequenceManagerConfig, *, sequence_manager: Optional[RemoteSequenceManager] = None, dht: Optional[DHT] = None, @@ -40,7 +39,7 @@ class RemoteSequential(nn.Module): if start_block is None: start_block = 0 if end_block is None: - end_block = self.config.n_layer + end_block = self.config.num_hidden_layers block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block)) sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht) self.sequence_manager = sequence_manager diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 6ac7bb0..1a31d66 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -20,6 +20,7 @@ from hivemind.utils.logging import get_logger import petals.dht_utils from petals.client.routing.sequence_info import RemoteSequenceInfo from petals.client.routing.spending_policy import NoSpendingPolicy +from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState from petals.server.handler import TransformerConnectionHandler @@ -28,6 +29,10 @@ logger = get_logger(__name__) @dataclasses.dataclass class SequenceManagerConfig: + initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT + dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name) + daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers + allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests @@ -73,6 +78,8 @@ class RemoteSequenceManager: dht: Optional[DHT] = None, state: Optional[SequenceManagerState] = None, ): + assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or 
`dht`" + assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." assert len(block_uids) > 0, "Sequences must contain at least one block" self.config = config @@ -84,7 +91,7 @@ class RemoteSequenceManager: dht = DHT( initial_peers=config.initial_peers, client_mode=True, - num_workers=config.n_layer, + num_workers=config.num_hidden_layers, startup_timeout=config.daemon_startup_timeout, start=True, ) diff --git a/src/petals/models/__init__.py b/src/petals/models/__init__.py new file mode 100644 index 0000000..acb4d38 --- /dev/null +++ b/src/petals/models/__init__.py @@ -0,0 +1,2 @@ +from petals.models.bloom import * +from petals.models.llama import * diff --git a/src/petals/models/bloom/__init__.py b/src/petals/models/bloom/__init__.py new file mode 100644 index 0000000..911974b --- /dev/null +++ b/src/petals/models/bloom/__init__.py @@ -0,0 +1,7 @@ +from petals.models.bloom.block import WrappedBloomBlock +from petals.models.bloom.config import DistributedBloomConfig +from petals.models.bloom.model import ( + DistributedBloomForCausalLM, + DistributedBloomForSequenceClassification, + DistributedBloomModel, +) diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py new file mode 100644 index 0000000..f246bd8 --- /dev/null +++ b/src/petals/models/bloom/block.py @@ -0,0 +1,32 @@ +""" +Bloom intermediate layer +Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b +See commit history for authorship. +""" +from typing import Optional, Tuple + +import torch +from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor + + +class WrappedBloomBlock(BloomBlock): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs + ): + assert attention_mask is None, "Non-causal attention masks are not supported yet" + batch_size, seq_length = hidden_states.shape[:2] + past_length = 0 if layer_past is None else layer_past[0].shape[-1] + seq_length_with_past = seq_length + past_length + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length) + return super().forward( + hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs + ) diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py new file mode 100644 index 0000000..57c3e7b --- /dev/null +++ b/src/petals/models/bloom/config.py @@ -0,0 +1,35 @@ +import os +from typing import Optional, Union + +from hivemind import get_logger +from transformers.models.bloom import BloomConfig +from transformers.models.bloom.modeling_bloom import BloomAttention + +from petals.client.lm_head import LMHeadConfig +from petals.client.ptune import PTuneConfig +from petals.client.routing.sequence_manager import SequenceManagerConfig +from petals.models.bloom.block import WrappedBloomBlock +from petals.utils.auto_config import AutoDistributedConfig +from petals.utils.version import get_compatible_model_repo + +logger = get_logger(__name__) + + +class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig): + 
block_class = WrappedBloomBlock + attn_class = BloomAttention + block_prefix = "h" + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs + ): + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) + if loading_from_repo and dht_prefix is None: + # We need "-petals" for backward compatibility with Petals < 1.2.0 + dht_prefix = str(model_name_or_path) + "-petals" + logger.info(f"Using DHT prefix: {dht_prefix}") + return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + + +AutoDistributedConfig.register(DistributedBloomConfig) diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py new file mode 100644 index 0000000..fae9faf --- /dev/null +++ b/src/petals/models/bloom/model.py @@ -0,0 +1,134 @@ +from typing import Optional + +import hivemind +import torch +import torch.nn as nn +from hivemind.utils.logging import get_logger +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel + +from petals.client.from_pretrained import FromPretrainedMixin +from petals.client.lm_head import LMHead +from petals.client.ptune import PTuneMixin +from petals.client.remote_generation import RemoteGenerationMixin +from petals.client.remote_sequential import RemoteSequential +from petals.models.bloom.config import DistributedBloomConfig + +logger = get_logger(__name__) + + +class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): + """BloomModel, but all transformer layers are hosted by the swarm""" + + _keys_to_ignore_on_load_missing = ( + BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing + ) + _keys_to_ignore_on_load_unexpected = [r"^h\."] + + config_class = DistributedBloomConfig + + def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None): + n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization + super().__init__(config) + assert len(self.h) == 0 + config.num_hidden_layers = n_layer + + self.h = RemoteSequential(config, dht=dht) + + self.set_requires_grad(False) # Forbid accumulate grads for embeddings and layernorm + self.init_prompts(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now" + + for k, v in kwargs.items(): + if not (v is None or v is False): + logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})") + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + batch_size = inputs_embeds.shape[0] + prompts, intermediate_prompts = self.get_prompt(batch_size) + 
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + output_shape = input_shape + (hidden_states.size(-1),) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = self.h(hidden_states, prompts=intermediate_prompts) + else: + hidden_states = self.h(hidden_states) + + # Remove prefix + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = hidden_states[:, self.pre_seq_len :] + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM): + _keys_to_ignore_on_load_missing = ( + BloomForCausalLM._keys_to_ignore_on_load_missing + + DistributedBloomModel._keys_to_ignore_on_load_missing + + [r"^lm_head\."] # Missing since they are shared with input embeddings + ) + _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedBloomConfig + + def __init__(self, config: DistributedBloomConfig): + BloomPreTrainedModel.__init__(self, config) + self.transformer = DistributedBloomModel(config) + self.lm_head = LMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + +class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification): + _keys_to_ignore_on_load_missing = ( + BloomForSequenceClassification._keys_to_ignore_on_load_missing + + DistributedBloomModel._keys_to_ignore_on_load_missing + ) + _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedBloomConfig + + def __init__(self, config: DistributedBloomConfig): + BloomPreTrainedModel.__init__(self, config) + self.num_labels = config.num_labels + + self.transformer = DistributedBloomModel(config) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype) + + # Initialize weights and apply final processing + self.post_init() diff --git a/src/petals/models/llama/__init__.py b/src/petals/models/llama/__init__.py new file mode 100644 index 0000000..8156939 --- /dev/null +++ b/src/petals/models/llama/__init__.py @@ -0,0 +1,7 @@ +from petals.models.llama.block import WrappedLlamaBlock +from petals.models.llama.config import DistributedLlamaConfig +from petals.models.llama.model import ( + DistributedLlamaForCausalLM, + DistributedLlamaForSequenceClassification, + DistributedLlamaModel, +) diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py new file mode 100644 index 0000000..2f07188 --- /dev/null +++ b/src/petals/models/llama/block.py @@ -0,0 +1,87 @@ +""" +LLaMA intermediate layer +Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +See commit history for authorship. 
+""" +from typing import Optional, Tuple + +import torch +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel + + +class WrappedLlamaBlock(LlamaDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + batch_size, seq_length, _ = hidden_states.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + past_key_value = layer_past + if past_key_value is not None: + past_key_values_length = past_key_value[0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) + + if position_ids is None: + device = hidden_states.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = LlamaModel._prepare_decoder_attention_mask( + None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_llama_to_bloom( + present_key_value, batch_size, seq_length_with_past + ) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_llama( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + key_states = key_states.permute(0, 2, 1) + key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim) + value_states = value_states.view(*key_states.shape) + return (key_states, value_states) + + def _reorder_cache_from_llama_to_bloom( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim) + key_states = key_states.view(*value_states.shape) + key_states = key_states.permute(0, 2, 1) + return (key_states, value_states) diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py new file mode 100644 index 0000000..a7e6681 --- /dev/null +++ b/src/petals/models/llama/config.py @@ -0,0 +1,35 @@ +import os +from typing import Optional, Union + +from hivemind import get_logger +from transformers.models.llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention + +from petals.client.lm_head import LMHeadConfig +from petals.client.ptune import PTuneConfig +from petals.client.routing.sequence_manager import SequenceManagerConfig +from petals.models.llama.block import WrappedLlamaBlock +from petals.utils.auto_config import 
AutoDistributedConfig + +logger = get_logger(__name__) + + +class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig): + block_class = WrappedLlamaBlock + attn_class = LlamaAttention + block_prefix = "model.layers" + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs + ): + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) + if loading_from_repo and dht_prefix is None: + dht_prefix = str(model_name_or_path) + if "/" in dht_prefix: # If present, strip the repo owner to merge blocks hosted by different accounts + dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :] + logger.info(f"Using DHT prefix: {dht_prefix}") + return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + + +AutoDistributedConfig.register(DistributedLlamaConfig) diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py new file mode 100644 index 0000000..37b4683 --- /dev/null +++ b/src/petals/models/llama/model.py @@ -0,0 +1,152 @@ +from typing import Optional + +import hivemind +import torch +import torch.nn as nn +from hivemind.utils.logging import get_logger +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel + +from petals.client.from_pretrained import FromPretrainedMixin +from petals.client.lm_head import LMHead +from petals.client.ptune import PTuneMixin +from petals.client.remote_generation import RemoteGenerationMixin +from petals.client.remote_sequential import RemoteSequential +from petals.models.llama.config import DistributedLlamaConfig + +logger = get_logger(__name__) + + +class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): + """LlamaModel, but all transformer layers are hosted by the swarm""" + + _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."] + + config_class = DistributedLlamaConfig + + def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None): + n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization + super().__init__(config) + assert len(self.layers) == 0 + config.num_hidden_layers = n_layer + + self.layers = RemoteSequential(config, dht=dht) + + self.set_requires_grad(False) # Forbid accumulate grads for embeddings and layernorm + self.init_prompts(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> BaseModelOutputWithPast: + assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now" + + for k, v in kwargs.items(): + if not (v is None or v is False): + logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})") + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or
inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + batch_size = inputs_embeds.shape[0] + prompts, intermediate_prompts = self.get_prompt(batch_size) + inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + + hidden_states = inputs_embeds + output_shape = input_shape + (hidden_states.size(-1),) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = self.layers(hidden_states, prompts=intermediate_prompts) + else: + hidden_states = self.layers(hidden_states) + + # Remove prefix + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = hidden_states[:, self.pre_seq_len :] + + # Add last hidden state + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.view(output_shape) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + @property + def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin + return self.embed_tokens + + @property + def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin + return nn.Identity() + + @property + def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin + return self.layers + + @property + def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin + return self.norm + + +class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM): + _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedLlamaConfig + + def __init__(self, config: DistributedLlamaConfig): + LlamaPreTrainedModel.__init__(self, config) + self.model = DistributedLlamaModel(config) + self.lm_head = LMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + @property + def transformer(self) -> DistributedLlamaModel: # For compatibility with RemoteGenerationMixin + return self.model + + +class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification): + _keys_to_ignore_on_load_missing = ( + LlamaForSequenceClassification._keys_to_ignore_on_load_missing + + DistributedLlamaModel._keys_to_ignore_on_load_missing + ) + _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedLlamaConfig + + def __init__(self, config): + LlamaPreTrainedModel.__init__(self, config) + self.num_labels = config.num_labels + + self.model = DistributedLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @property + def transformer(self) -> DistributedLlamaModel: # For compatibility with RemoteGenerationMixin + return self.model diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 76dc52b..adcd617 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -1,4 +1,3 @@ -"""Code for serving bloom blocks via hivemind-server""" from __future__ import annotations from collections import Counter @@ -12,8 +11,7 @@ from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import 
get_logger from tensor_parallel import TensorParallel from tensor_parallel.tensor_parallel import PerDeviceTensors -from transformers import BloomConfig -from transformers.models.bloom.modeling_bloom import BloomAttention +from transformers import PretrainedConfig from petals.data_structures import InferenceMetadata from petals.server.memory_cache import MemoryCache @@ -24,17 +22,19 @@ logger = get_logger(__name__) class TransformerBackend(ModuleBackend): - """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference""" + """A wrapper for a transformer block that can process requests for forward, backward and inference""" - def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs): + def __init__( + self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs + ): super().__init__(*args, **kwargs) assert isinstance(self.module, TensorParallel) self.config = config self.memory_cache = memory_cache for name, param in self.module.named_parameters(): - assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does" + assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does" for name, buf in self.module.named_buffers(): - assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does" + assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does" max_batch_size = self.forward_pool.max_batch_size device = self.module.devices[self.module.output_device_index] @@ -52,9 +52,10 @@ class TransformerBackend(ModuleBackend): self.shard_num_heads = [] for shard in self.module.module_shards: for submodule in shard.modules(): - if isinstance(submodule, BloomAttention): + if isinstance(submodule, config.attn_class): self.shard_num_heads.append(submodule.num_heads) - assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head + assert len(self.shard_num_heads) == len(self.module.devices) + assert sum(self.shard_num_heads) == config.num_attention_heads self.inference_schema = ( ( @@ -71,7 +72,7 @@ class TransformerBackend(ModuleBackend): def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]: """Create tensor descriptors for attention cache tensors used during inference_step""" - head_dim = self.config.hidden_size // self.config.n_head + head_dim = self.config.hidden_size // self.config.num_attention_heads cache_tensors = [] for device, num_heads in zip(self.module.devices, self.shard_num_heads): keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device) diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index 8d59d18..a6af3b0 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -2,12 +2,10 @@ from typing import Optional, Union import torch from accelerate import init_empty_weights -from transformers import BloomConfig +from transformers import PretrainedConfig -from petals.bloom.block import WrappedBloomBlock - -def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> torch.dtype: +def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype: """If dtype is "auto", resolves it using BloomConfig. 
Returns `dtype` intact otherwise.""" if dtype not in ("auto", None): return dtype @@ -17,7 +15,7 @@ def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> def get_block_size( - config: BloomConfig, + config: PretrainedConfig, location: str, *, dtype: Optional[Union[str, torch.dtype]] = None, @@ -30,7 +28,7 @@ def get_block_size( ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations' with init_empty_weights(include_buffers=True): - block = WrappedBloomBlock(config) + block = config.block_class(config) n_params = sum(param.numel() for param in block.parameters()) if location == "memory" and load_in_8bit: diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py new file mode 100644 index 0000000..aab8a9e --- /dev/null +++ b/src/petals/server/from_pretrained.py @@ -0,0 +1,175 @@ +""" +Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code. +If necessary, one can rewrite this to implement a different behavior, such as: + - loading files from a local data source (e.g. S3) + - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to ) + - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html ) + +""" +import json +import time +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from hivemind.utils.logging import get_logger +from huggingface_hub import get_hf_file_metadata, hf_hub_url +from transformers import PretrainedConfig +from transformers.utils import get_file_from_repo + +from petals.server.block_utils import resolve_block_dtype +from petals.utils.auto_config import AutoDistributedConfig +from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for + +logger = get_logger(__name__) + + +def load_pretrained_block( + model_name: str, + block_index: int, + *, + config: Optional[PretrainedConfig] = None, + torch_dtype: Union[torch.dtype, str] = "auto", + revision: Optional[str] = None, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, + max_disk_space: Optional[int] = None, +) -> nn.Module: + if config is None: + config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token) + if cache_dir is None: + cache_dir = DEFAULT_CACHE_DIR + + assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" + torch_dtype = resolve_block_dtype(config, torch_dtype) + + with init_empty_weights(): + block = config.block_class(config) + + block_prefix = f"{config.block_prefix}.{block_index}." 
+ state_dict = _load_state_dict_from_repo( + model_name, + block_prefix, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + max_disk_space=max_disk_space, + ) + + # dummy load, check that keys match + report = block.load_state_dict(state_dict, strict=True) + assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" + + for param_name, _ in block.named_parameters(): + assert param_name in state_dict, f"{param_name} not in state dict" + param = state_dict[param_name] + if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + param = param.to(torch_dtype) + set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) + + logger.info(f"Loaded {model_name} block {block_index}, {report}") + return block + + +StateDict = Dict[str, torch.Tensor] + + +def _load_state_dict_from_repo( + model_name: str, + block_prefix: str, + *, + revision: Optional[str] = None, + use_auth_token: Optional[str] = None, + cache_dir: str, + max_disk_space: Optional[int] = None, +) -> StateDict: + index_file = get_file_from_repo( + model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir + ) + if index_file is not None: # Sharded model + with open(index_file) as f: + index = json.load(f) + filenames = { + filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix) + } + if not filenames: + raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}") + else: # Non-sharded model + filenames = {"pytorch_model.bin"} + logger.debug(f"Loading {block_prefix}* from {filenames}") + + state_dict = {} + for filename in filenames: + shard_state_dict = _load_state_dict_from_file( + model_name, + filename, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + max_disk_space=max_disk_space, + ) + shard_state_dict = { + param_name[len(block_prefix) :]: param + for param_name, param in shard_state_dict.items() + if param_name.startswith(block_prefix) + } # Remove unused parameters from memory + state_dict.update(shard_state_dict) + return state_dict + + +def _load_state_dict_from_file( + model_name: str, + filename: str, + *, + revision: Optional[str] = None, + use_auth_token: Optional[str] = None, + cache_dir: str, + max_disk_space: Optional[int] = None, + delay: float = 30, +) -> StateDict: + # First, try to find the weights locally + try: + with allow_cache_reads(cache_dir): + path = get_file_from_repo( + model_name, + filename, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + local_files_only=True, + ) + if path is not None: + return torch.load(path, map_location="cpu") + except Exception: + logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True) + + # If not found, ensure that we have enough disk space to download them (maybe remove something) + while True: + try: + with allow_cache_writes(cache_dir): + url = hf_hub_url(model_name, filename, revision=revision) + file_size = get_hf_file_metadata(url, token=use_auth_token).size + if file_size is not None: + free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) + else: + logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}") + + path = get_file_from_repo( + model_name, + filename, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + local_files_only=False, + ) + if path 
is None: + raise RuntimeError(f"File {filename} does not exist in repo {model_name}") + return torch.load(path, map_location="cpu") + except Exception as e: + logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True) + time.sleep(delay) + + +DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/server/server.py b/src/petals/server/server.py index e424fb5..75a999e 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -14,21 +14,23 @@ from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.runtime import Runtime from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger -from transformers import BloomConfig +from transformers import PretrainedConfig -from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size, resolve_block_dtype +from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability from petals.server.throughput import get_dtype_name, get_server_throughput +from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR +from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) @@ -53,7 +55,7 @@ class Server: max_batch_size: int = 2048, inference_max_length: int = 2048, torch_dtype: str = "auto", - revision: str = "main", + revision: Optional[str] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, attn_cache_tokens: int = 8192, @@ -83,25 +85,32 @@ class Server: ): """Create a server with one or more bloom blocks. 
See run_server.py for documentation.""" + converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path) self.converted_model_name_or_path = converted_model_name_or_path + self.num_handlers = num_handlers self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size self.inference_max_length = inference_max_length self.compression = compression self.stats_report_interval, self.update_period = stats_report_interval, update_period self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads - self.use_auth_token = use_auth_token + self.revision, self.use_auth_token = revision, use_auth_token if custom_module_path is not None: add_custom_models_from_file(custom_module_path) + self.block_config = AutoDistributedConfig.from_pretrained( + converted_model_name_or_path, + use_auth_token=use_auth_token, + revision=revision, + ) + if prefix is None: - prefix = converted_model_name_or_path - assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, ( - f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " - f"Please specify --prefix manually when starting a server" - ) - logger.debug(f"Automatic dht prefix: {prefix}") + prefix = self.block_config.dht_prefix + assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, ( + f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. " + f"Please specify another --prefix manually when starting a server" + ) self.prefix = prefix if expiration is None: @@ -111,12 +120,9 @@ class Server: self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout - self.block_config = BloomConfig.from_pretrained( - converted_model_name_or_path, - use_auth_token=use_auth_token, - revision=revision, - ) - self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] + self.module_uids = [ + f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers) + ] if dht_client_mode is None: is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs) @@ -125,7 +131,7 @@ class Server: self.dht = DHT( initial_peers=initial_peers, start=True, - num_workers=self.block_config.n_layer, + num_workers=self.block_config.num_hidden_layers, use_relay=use_relay, use_auto_relay=use_auto_relay, client_mode=dht_client_mode, @@ -161,10 +167,10 @@ class Server: if load_in_8bit is None: load_in_8bit = device.type == "cuda" self.load_in_8bit = load_in_8bit - logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format") + logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format") - max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens - self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8 + cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens + self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8 assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" if num_blocks is None and block_indices is None: @@ -192,6 +198,7 @@ class Server: assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: throughput = get_server_throughput( + converted_model_name_or_path, self.block_config, device, torch_dtype, @@ -239,11 +246,12 @@ class 
Server: num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block)) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" + num_blocks = min(num_blocks, self.block_config.num_hidden_layers) logger.info( f"Server will fill all your GPU memory with {num_blocks} transformer blocks. " f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually" ) - return min(num_blocks, self.block_config.n_layer) + return num_blocks def run(self): while True: @@ -274,6 +282,7 @@ class Server: step_timeout=self.step_timeout, prefetch_batches=self.prefetch_batches, sender_threads=self.sender_threads, + revision=self.revision, use_auth_token=self.use_auth_token, load_in_8bit=self.load_in_8bit, tensor_parallel_devices=self.tensor_parallel_devices, @@ -352,7 +361,7 @@ class ModuleContainer(threading.Thread): dht: DHT, prefix: str, converted_model_name_or_path: str, - block_config: BloomConfig, + block_config: PretrainedConfig, attn_cache_bytes: int, alloc_timeout: float, throughput: float, @@ -366,6 +375,7 @@ class ModuleContainer(threading.Thread): compression: CompressionType, update_period: float, expiration: Optional[float], + revision: Optional[str], use_auth_token: Optional[str], load_in_8bit: bool, tensor_parallel_devices: Sequence[torch.device], @@ -394,14 +404,14 @@ class ModuleContainer(threading.Thread): block = load_pretrained_block( converted_model_name_or_path, block_index, - block_config, + config=block_config, torch_dtype=torch_dtype, + revision=revision, use_auth_token=use_auth_token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True) - blocks[module_uid] = TransformerBackend( module_uid, block, @@ -564,13 +574,9 @@ class ModuleContainer(threading.Thread): self.ready.clear() + logger.debug("Shutting down connection handlers") for handler in self.conn_handlers: handler.shutdown() - logger.debug("Connection handlers terminated") - - if self.checkpoint_saver is not None: - self.checkpoint_saver.stop.set() - self.checkpoint_saver.join() logger.debug(f"Shutting down pools") for pool in self.runtime.pools: diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index dbefb35..2ee1ca1 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -5,15 +5,13 @@ import multiprocessing as mp import os import time from collections import Counter -from hashlib import sha256 from pathlib import Path from typing import Dict, Optional, Sequence, Union import torch from hivemind.utils.logging import get_logger -from transformers import BloomConfig +from transformers import PretrainedConfig -from petals.bloom.block import WrappedBloomBlock from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_block import convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR @@ -35,7 +33,8 @@ if not hasattr(speedtest, "Speedtest"): def get_server_throughput( - config: BloomConfig, + model_name: str, + config: PretrainedConfig, device: torch.device, dtype: Union[str, torch.dtype], *, @@ -59,7 +58,7 @@ def get_server_throughput( fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX) # The OS will release the lock when lock_fd is closed or the process is killed - cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}" + cache_key = f"model_{model_name}" cache_key += f"_device_{get_device_name(device).replace(' ', '_')}" 
cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}" if len(tensor_parallel_devices) > 1: @@ -101,7 +100,7 @@ def get_server_throughput( def measure_throughput_info( - config: BloomConfig, + config: PretrainedConfig, device: torch.device, dtype: torch.dtype, *, @@ -127,7 +126,7 @@ def measure_throughput_info( return throughput_info -def measure_network_rps(config: BloomConfig, *, timeout: float = 60) -> Optional[float]: +def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]: pipe_recv, pipe_send = mp.Pipe(duplex=False) process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,)) process.start() @@ -160,7 +159,7 @@ def _measure_bits_per_second(pipe_send: mp.Pipe): def measure_compute_rps( - config: BloomConfig, + config: PretrainedConfig, device: torch.device, dtype: torch.dtype, *, @@ -172,7 +171,7 @@ def measure_compute_rps( if not tensor_parallel_devices: tensor_parallel_devices = (device,) with torch.inference_mode(): - block = WrappedBloomBlock(config).to(dtype) + block = config.block_class(config).to(dtype) block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True) cache = None @@ -203,4 +202,7 @@ def get_device_name(device: torch.device) -> str: def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str: - return "8-bit" if load_in_8bit else str(dtype) + name = str(dtype) + if load_in_8bit: + name += ", 8-bit quantized" + return name diff --git a/src/petals/utils/__init__.py b/src/petals/utils/__init__.py index e69de29..654e98c 100644 --- a/src/petals/utils/__init__.py +++ b/src/petals/utils/__init__.py @@ -0,0 +1 @@ +from petals.utils.auto_config import AutoDistributedConfig diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py new file mode 100644 index 0000000..b6fca41 --- /dev/null +++ b/src/petals/utils/auto_config.py @@ -0,0 +1,23 @@ +from typing import Type + +from transformers import AutoConfig, PretrainedConfig + +CONFIG_MAPPING = {} # Populated with AutoDistributedConfig.register() + + +class AutoDistributedConfig: + @classmethod + def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig: + config = AutoConfig.from_pretrained(*args, **kwargs) + if config.model_type not in CONFIG_MAPPING: + raise ValueError(f"Petals does not support model type {config.model_type}") + + dist_config_class = CONFIG_MAPPING[config.model_type] + return dist_config_class.from_pretrained(*args, **kwargs) + + @staticmethod + def register(config_class: Type[PretrainedConfig]) -> None: + assert issubclass(config_class, PretrainedConfig) + assert config_class.model_type not in CONFIG_MAPPING + + CONFIG_MAPPING[config_class.model_type] = config_class diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index b58cd1a..28aea56 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -10,18 +10,15 @@ import torch import torch.nn as nn from hivemind.utils.logging import get_logger, use_hivemind_log_handler from tensor_parallel.slicing_configs import get_bloom_config -from transformers import BloomConfig -from transformers.models.bloom.modeling_bloom import BloomAttention - -from petals.bloom.block import WrappedBloomBlock +from transformers import PretrainedConfig use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) def convert_block( - block: WrappedBloomBlock, - config: BloomConfig, + block: nn.Module, + config: PretrainedConfig, tensor_parallel_devices: Sequence[torch.device], 
output_device: torch.device, load_in_8bit: bool, @@ -58,7 +55,7 @@ def convert_block( return block -def replace_8bit_linear(model: nn.Module, threshold=6.0): +def replace_8bit_linear(model: nn.Module, threshold=6.0) -> nn.Module: """ A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes` library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): @@ -100,17 +97,22 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0): def make_tensor_parallel( - block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device -): - tp_config = get_bloom_config(model_config, devices) - del tp_config.state_rules[re.compile(".*word_embeddings.weight$")] + block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device +) -> nn.Module: + if model_config.model_type == "bloom": + tp_config = get_bloom_config(model_config, devices) + del tp_config.state_rules[re.compile(".*word_embeddings.weight$")] + else: + if len(devices) > 1: + logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution") + tp_config = None tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True) total_heads = 0 for tp_shard in tp_block.module_shards: for submodule in tp_shard.modules(): - if isinstance(submodule, BloomAttention): + if isinstance(submodule, model_config.attn_class): total_heads += submodule.num_heads - assert total_heads == model_config.n_head + assert total_heads == model_config.num_attention_heads return tp_block diff --git a/src/petals/utils/disk_cache.py b/src/petals/utils/disk_cache.py index 3217e34..aefea1d 100644 --- a/src/petals/utils/disk_cache.py +++ b/src/petals/utils/disk_cache.py @@ -57,13 +57,16 @@ def free_disk_space_for( available_space = shutil.disk_usage(cache_dir).free - os_quota if max_disk_space is not None: available_space = min(available_space, max_disk_space - occupied_space) + + gib = 1024**3 + logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB") if size <= available_space: return revisions = [revision for repo in model_repos for revision in repo.revisions] revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified)) - # Remove as few least recently used blocks as possible + # Remove as few least recently used shards as possible pending_removal = [] freed_space = 0 extra_space_needed = size - available_space @@ -73,9 +76,8 @@ def free_disk_space_for( if freed_space >= extra_space_needed: break - gib = 1024**3 if pending_removal: - logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space") + logger.info(f"Removing {len(pending_removal)} shards to free {freed_space / gib:.1f} GiB of disk space") delete_strategy = cache_info.delete_revisions(*pending_removal) delete_strategy.execute() diff --git a/src/petals/utils/version.py b/src/petals/utils/version.py index f4a5be1..67b3866 100644 --- a/src/petals/utils/version.py +++ b/src/petals/utils/version.py @@ -1,3 +1,7 @@ +import os +import re +from typing import Union + import requests from hivemind.utils.logging import TextStyle, get_logger from packaging.version import parse @@ -7,7 +11,7 @@ import petals logger = get_logger(__name__) -def validate_version(): +def validate_version() -> None: 
logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}") try: r = requests.get("https://pypi.python.org/pypi/petals/json") @@ -24,3 +28,17 @@ def validate_version(): ) except Exception as e: logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True) + + +def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]: + if model_name_or_path is None: + return None + + match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path)) + if match is None: + return model_name_or_path + + logger.info( + f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones" + ) + return match.group(1) diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index 6909ccf..d42666b 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -1,7 +1,7 @@ import pytest import torch -from petals.client import DistributedBloomConfig +from petals import AutoDistributedConfig from petals.server.throughput import measure_compute_rps from test_utils import MODEL_NAME @@ -9,7 +9,7 @@ from test_utils import MODEL_NAME @pytest.mark.forked @pytest.mark.parametrize("tensor_parallel", [False, True]) def test_compute_throughput(tensor_parallel: bool): - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) + config = AutoDistributedConfig.from_pretrained(MODEL_NAME) tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else () compute_rps = measure_compute_rps( config, diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index a05387d..62c4e89 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -1,13 +1,10 @@ import random -from typing import Union import pytest import torch -from transformers.models.bloom.configuration_bloom import BloomConfig -from petals.bloom.block import WrappedBloomBlock -from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block -from petals.client import DistributedBloomConfig, RemoteSequential +from petals import DistributedBloomConfig, RemoteSequential +from petals.server.from_pretrained import load_pretrained_block from test_utils import * @@ -16,21 +13,22 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) remote_sequential = RemoteSequential(config) - for block_index in random.sample(range(config.n_layer), 3): + for block_index in random.sample(range(config.num_hidden_layers), 3): remote_block = remote_sequential[block_index] inputs = torch.randn(1, 8, config.hidden_size) outputs_forward = remote_block(inputs) outputs_inference = [] - with remote_block.inference_session(max_length=inputs.shape[1]) as sess: - for i in range(inputs.shape[1]): - outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) - - # test that max length is respected - with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: - sess.step(inputs[:, -1:, :]) - assert "Maximum length exceeded" in repr(exc_info.value) + with torch.inference_mode(): + with remote_block.inference_session(max_length=inputs.shape[1]) as sess: + for i in range(inputs.shape[1]): + outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) + + # test that max length is respected + with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: + sess.step(inputs[:, -1:, :]) + assert "Maximum length exceeded" 
in repr(exc_info.value) outputs_inference = torch.cat(outputs_inference, dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) @@ -38,47 +36,3 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward) assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference) - - -def _old_load_pretrained_block( - converted_model_name_or_path: str, - block_index: int, - torch_dtype: Union[torch.dtype, str] = "auto", -) -> WrappedBloomBlock: - """Load the BLOOM block by directly initializing the weights. - This test is used to check consistency with the previous implementation and can be removed in the future.""" - config = BloomConfig.from_pretrained(converted_model_name_or_path) - - block = WrappedBloomBlock(config) - state_dict = _load_state_dict( - converted_model_name_or_path, - block_index, - config, - cache_dir=None, - ) - - if torch_dtype == "auto": - with torch.no_grad(): - for name, param in block.named_parameters(): - assert name in state_dict, f"{name} not in state dict" - param.data = param.data.to(state_dict[name].dtype) - else: - assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - block = block.to(dtype=torch_dtype) - - block.load_state_dict(state_dict, strict=True) - return block - - -@pytest.mark.forked -def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8): - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) - torch.random.manual_seed(0) - inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype) - - block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype) - ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype) - - outputs = block.forward(inputs)[0] - outputs_ref = ref_block.forward(inputs)[0] - assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward) diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 15f3b5c..d20f654 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -7,9 +7,9 @@ import pytest import torch -from petals.bloom.from_pretrained import load_pretrained_block -from petals.client import DistributedBloomConfig +from petals import DistributedBloomConfig from petals.client.remote_sequential import RemoteSequential +from petals.server.from_pretrained import load_pretrained_block from test_utils import * diff --git a/tests/test_dtype.py b/tests/test_dtype.py index 03afd83..d102077 100644 --- a/tests/test_dtype.py +++ b/tests/test_dtype.py @@ -1,17 +1,16 @@ import pytest import torch -from petals.bloom.from_pretrained import load_pretrained_block -from petals.client import DistributedBloomConfig from petals.server.block_utils import resolve_block_dtype +from petals.server.from_pretrained import load_pretrained_block +from petals.utils.auto_config import AutoDistributedConfig from test_utils import MODEL_NAME @pytest.mark.forked @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"]) -def test_backend_dtype(torch_dtype): - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) - block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype=torch_dtype) - backend_dtype = resolve_block_dtype(config, torch_dtype) - other_backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype - assert backend_dtype == other_backend_dtype +def test_block_dtype(torch_dtype): + config = 
AutoDistributedConfig.from_pretrained(MODEL_NAME) + block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype) + expected_dtype = resolve_block_dtype(config, torch_dtype) + assert all(param.dtype == expected_dtype for param in block.parameters()) diff --git a/tests/test_full_model.py b/tests/test_full_model.py index cef002e..f2679f2 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -5,7 +5,7 @@ from hivemind import get_logger from transformers.generation import BeamSearchScorer from transformers.models.bloom import BloomForCausalLM -from petals.client.remote_model import DistributedBloomForCausalLM +from petals import DistributedBloomForCausalLM from test_utils import * logger = get_logger(__name__) @@ -20,7 +20,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato ) config = model.config assert isinstance(model, DistributedBloomForCausalLM) - assert len(model.transformer.h) == model.config.n_layer + assert len(model.transformer.h) == model.config.num_hidden_layers test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index d46ca1c..734683f 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -4,10 +4,10 @@ import torch.nn.functional as F from hivemind import DHT, BatchTensorDescriptor, get_logger from hivemind.proto import runtime_pb2 -from petals.bloom.from_pretrained import load_pretrained_block +from petals import DistributedBloomConfig from petals.client import RemoteSequenceManager, RemoteSequential -from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER +from petals.server.from_pretrained import load_pretrained_block from test_utils import * logger = get_logger(__name__) @@ -28,10 +28,10 @@ def test_remote_sequential(): full_grad = test_inputs.grad.clone() test_inputs.grad.data.zero_() - first_half = sequential[: config.n_layer // 2] - second_half = sequential[config.n_layer // 2 :] + first_half = sequential[: config.num_hidden_layers // 2] + second_half = sequential[config.num_hidden_layers // 2 :] assert len(first_half) + len(second_half) == len(sequential) - assert abs(len(first_half) - len(second_half)) == config.n_layer % 2 + assert abs(len(first_half) - len(second_half)) == config.num_hidden_layers % 2 for m in sequential, first_half, second_half: assert isinstance(repr(m), str) @@ -46,7 +46,7 @@ def test_remote_sequential(): assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3) # test RemoteSequential with lossy compression - block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] + block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)] lossy_sequential = RemoteSequential( config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht) ) @@ -90,7 +90,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1) output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1) input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1) - intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True) + intermediate_prompts = torch.randn( + config.num_hidden_layers, 
batch_size, pre_seq_len, config.hidden_size, requires_grad=True + ) input_prompts = input_prompts.detach().requires_grad_(True) intermediate_prompts = intermediate_prompts.detach().requires_grad_(True) @@ -110,7 +112,7 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): assert intermediate_prompts_ref.grad is None outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1) - for block_index in range(config.n_layer): + for block_index in range(config.num_hidden_layers): block_prompt = intermediate_prompts_ref[block_index] outputs_ref[:, : block_prompt.shape[1]] += block_prompt diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 38e9a8a..86d04ca 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -5,8 +5,8 @@ import pytest import torch from hivemind import DHT, get_logger +from petals import DistributedBloomConfig from petals.client import RemoteSequenceManager, RemoteSequential -from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER from test_utils import * @@ -22,7 +22,7 @@ def test_sequence_manager_basics(mode: str): shutdown_evt = threading.Event() # test RemoteSequential with lossy compression - block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] + block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)] sequential = RemoteSequential( config, sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt), diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index 11d2565..5de3393 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -4,7 +4,7 @@ import hivemind import pytest import torch -from petals.client import DistributedBloomConfig, RemoteSequential +from petals import DistributedBloomConfig, RemoteSequential from petals.server.handler import CACHE_TOKENS_AVAILABLE from test_utils import * diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 84fcab4..408a261 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -6,7 +6,7 @@ import transformers from tensor_parallel import TensorParallel from tensor_parallel.slicing_configs import get_bloom_config -from petals.bloom.from_pretrained import load_pretrained_block +from petals.server.from_pretrained import load_pretrained_block from test_utils import MODEL_NAME
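
For illustration, a minimal client-side sketch of the API introduced by this patch (not part of the diff above). The repo name and initial peers are hypothetical placeholders; it assumes a swarm that already serves the corresponding blocks, that `initial_peers` can be passed through `from_pretrained` the same way the updated tests pass it to the config classes, and that the generation API accepts `max_new_tokens`:

import torch
from transformers import AutoTokenizer

from petals.models.llama import DistributedLlamaConfig, DistributedLlamaForCausalLM

MODEL_NAME = "someorg/llama-65b-hf"  # hypothetical repo in Hugging Face format
INITIAL_PEERS = ["/ip4/10.0.0.1/tcp/31337/p2p/QmExamplePeerId"]  # hypothetical bootstrap multiaddr

config = DistributedLlamaConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
# config.dht_prefix == "llama-65b-hf": the repo owner is stripped in DistributedLlamaConfig.from_pretrained

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = DistributedLlamaForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
with torch.inference_mode():
    outputs = model.generate(inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))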
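
On the server side, the new `petals.server.from_pretrained.load_pretrained_block` fetches and runs a single block without the old conversion step, mirroring the updated tests. A small sketch, again with a hypothetical repo name and assuming enough disk space for one checkpoint shard; the wrapped block builds its own attention mask and position ids:

import torch

from petals.server.from_pretrained import load_pretrained_block
from petals.utils.auto_config import AutoDistributedConfig

MODEL_NAME = "someorg/llama-65b-hf"  # hypothetical repo name

config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
block = load_pretrained_block(MODEL_NAME, 3, config=config, torch_dtype=torch.float32)

inputs = torch.randn(1, 8, config.hidden_size)
with torch.inference_mode():
    hidden_states = block(inputs)[0]  # the block returns a tuple; [0] is the hidden states
print(hidden_states.shape)  # (1, 8, config.hidden_size)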
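
Finally, the per-model attributes declared by each `Distributed*Config` are what keep the server code above model-agnostic: `block_class` is instantiated by `get_block_size`, `load_pretrained_block`, and `measure_compute_rps`; `attn_class` is used by `TransformerBackend` and `make_tensor_parallel` to count attention heads per shard; `block_prefix` selects the `<block_prefix>.<index>.*` keys of a checkpoint. A tiny sketch (assuming Petals with this patch installed) that prints these attributes for both supported families:

from petals.models.bloom import DistributedBloomConfig
from petals.models.llama import DistributedLlamaConfig

# Expected output:
#   bloom WrappedBloomBlock BloomAttention h
#   llama WrappedLlamaBlock LlamaAttention model.layers
for cfg in (DistributedBloomConfig, DistributedLlamaConfig):
    print(cfg.model_type, cfg.block_class.__name__, cfg.attn_class.__name__, cfg.block_prefix)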