From 0b0277ed6f5497a6ad33e4e97aad360150734e53 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 30 Dec 2022 07:58:33 +0400 Subject: [PATCH 001/168] Add link to chat.petals.ml (#168) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 15bb1f5..f71bfcd 100644 --- a/README.md +++ b/README.md @@ -45,8 +45,9 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache -- 💬 If you have any issues or feedback, please join [our Discord server](https://discord.gg/D9MwApKgWa)! -Check out more tutorials: +Check out more examples and tutorials: +- Chatbot web app: [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat) - Training a personified chatbot: [notebook](./examples/prompt-tuning-personachat.ipynb) - Fine-tuning BLOOM for text semantic classification: [notebook](./examples/prompt-tuning-sst2.ipynb) - Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) From 26e61202883f6e9418157d12d9facbffd7822bd6 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 30 Dec 2022 19:50:38 +0400 Subject: [PATCH 002/168] Fix code example in readme (#169) Makes it closer to runnable code, except for imports and defining tokenizer & data loader. --- README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index f71bfcd..1decf82 100644 --- a/README.md +++ b/README.md @@ -9,14 +9,14 @@ Generate text using distributed BLOOM and fine-tune it for your own tasks: ```python from petals import DistributedBloomForCausalLM +model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", tuning_mode="ptune", pre_seq_len=16) # Embeddings & prompts are on your device, BLOOM blocks are distributed across the Internet -model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", tuning_mode="ptune") inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"] outputs = model.generate(inputs, max_new_tokens=5) -print(tokenizer.decode(remote_outputs[0])) # A cat sat on a mat... +print(tokenizer.decode(outputs[0])) # A cat sat on a mat... -# Training (updates only prompts or adapters hosted locally) +# Fine-tuning (updates only prompts or adapters hosted locally) optimizer = torch.optim.AdamW(model.parameters()) for input_ids, labels in data_loader: outputs = model.forward(input_ids) @@ -34,13 +34,13 @@ Connect your own GPU and increase Petals capacity: ```bash # In an Anaconda env -(conda) $ conda install pytorch cudatoolkit=11.3 -c pytorch -(conda) $ pip install git+https://github.com/bigscience-workshop/petals -(conda) $ python -m petals.cli.run_server bigscience/bloom-petals +conda install pytorch cudatoolkit=11.3 -c pytorch +pip install git+https://github.com/bigscience-workshop/petals +python -m petals.cli.run_server bigscience/bloom-petals -# Or using a GPU-enabled Docker image -sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \ - python -m petals.cli.run_server bigscience/bloom-petals +# Or using our GPU-enabled Docker image +sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --rm \ + learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals ``` 💬 If you have any issues or feedback, please join [our Discord server](https://discord.gg/D9MwApKgWa)! From 4014442a0fcd622b9ed0b3d8d3b9d8b052c62398 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 30 Dec 2022 22:42:07 +0300 Subject: [PATCH 003/168] Fix instruction for developers (#170) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 1decf82..b459048 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,7 @@ __System requirements:__ Petals only supports Linux for now. If you don't have a Petals uses pytest with a few plugins. To install them, run: ```python +conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch git clone https://github.com/bigscience-workshop/petals.git && cd petals pip install -e .[dev] ``` From ff8ade8d3b1c7a379b29baec5c03a6044ce0b3e1 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Fri, 30 Dec 2022 21:52:57 +0000 Subject: [PATCH 004/168] Bump version to 1.0.0 --- setup.cfg | 2 +- src/petals/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index e8fbaed..15cdfe0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = petals -version = 1.0alpha1 +version = 1.0.0 author = Petals Developers author_email = petals-dev@googlegroups.com description = Easy way to efficiently run 100B+ language models without high-end GPUs diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 9998543..667094c 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,6 +1,6 @@ from petals.client import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.0alpha1" +__version__ = "1.0.0" _initialize_logs() From cdc3b6a25a8f6ed5d0d994f0c407a0c4d74be102 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 31 Dec 2022 02:22:40 +0400 Subject: [PATCH 005/168] Add PyPI badge, update instructions and links in readme (#172) --- README.md | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index b459048..121afe0 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@
Run 100B+ language models at home, BitTorrent-style.
Fine-tuning and inference up to 10x faster than offloading

+

Generate text using distributed BLOOM and fine-tune it for your own tasks: @@ -35,7 +36,7 @@ Connect your own GPU and increase Petals capacity: ```bash # In an Anaconda env conda install pytorch cudatoolkit=11.3 -c pytorch -pip install git+https://github.com/bigscience-workshop/petals +pip install -U petals python -m petals.cli.run_server bigscience/bloom-petals # Or using our GPU-enabled Docker image @@ -48,8 +49,8 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache -- Check out more examples and tutorials: - Chatbot web app: [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat) -- Training a personified chatbot: [notebook](./examples/prompt-tuning-personachat.ipynb) -- Fine-tuning BLOOM for text semantic classification: [notebook](./examples/prompt-tuning-sst2.ipynb) +- Training a personified chatbot: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) +- Fine-tuning BLOOM for text semantic classification: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) - Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Running a custom foundation model: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) @@ -92,12 +93,13 @@ Before building your own application that runs a language model with Petals, ple ## Installation Here's how to install Petals with conda: -``` -conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -pip install git+https://github.com/bigscience-workshop/petals + +```bash +conda install pytorch cudatoolkit=11.3 -c pytorch +pip install -U petals ``` -This script uses Anaconda to install cuda-enabled PyTorch. +This script uses Anaconda to install CUDA-enabled PyTorch. If you don't have anaconda, you can get it from [here](https://www.anaconda.com/products/distribution). If you don't want anaconda, you can install PyTorch [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install **PyTorch with CUDA 11** or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes). @@ -108,8 +110,8 @@ __System requirements:__ Petals only supports Linux for now. If you don't have a Petals uses pytest with a few plugins. To install them, run: -```python -conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch +```bash +conda install pytorch cudatoolkit=11.3 -c pytorch git clone https://github.com/bigscience-workshop/petals.git && cd petals pip install -e .[dev] ``` @@ -131,7 +133,7 @@ tail -f server1.log server2.log # view logs for both servers Then launch pytest: -``` +```bash export MODEL_NAME=bloom-testing/test-bloomd-560m-main REF_NAME=bigscience/bloom-560m export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v From 779959bc70a0270ceb46660db251087fe560337c Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 31 Dec 2022 02:51:52 +0400 Subject: [PATCH 006/168] Add link to PyPI (#173) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 121afe0..05f1ce2 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@
Run 100B+ language models at home, BitTorrent-style.
Fine-tuning and inference up to 10x faster than offloading

-
+

Generate text using distributed BLOOM and fine-tune it for your own tasks: From ae9e71fe8eb4e746c97852bd55977d943eba9df9 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 3 Jan 2023 18:35:51 +0300 Subject: [PATCH 007/168] Add local tensor-parallel fwd/bwd (#143) This pull request adds an option to run Petals server on multiple local GPUs. It uses https://github.com/BlackSamorez/tensor_parallel - 8bit approximation error same as in main (mean~=2% q0.9~=5%) - TP=1, 2, 3 (see screenshots above) - forward, grad w.r.t. input and inference exact match with main with TP=1 - `>=`80% GPU utilization with 3x 1080ti, batch = 8 tokens - throughput measured with and without TP - TP on 1080Tis has near-linear speedup comparable to the benchmarks (see first message) Co-authored-by: Iaroslav Lisniak Co-authored-by: Andrei Panferov Co-authored-by: Alexander Borzunov --- .github/workflows/run-tests.yaml | 6 +- setup.cfg | 1 + src/petals/cli/run_server.py | 6 +- src/petals/data_structures.py | 13 ++- src/petals/server/backend.py | 106 +++++++++++++++--------- src/petals/server/handler.py | 58 +++++-------- src/petals/server/memory_cache.py | 93 +++++++++++++-------- src/petals/server/server.py | 57 +++++++++---- src/petals/server/task_pool.py | 28 ++++--- src/petals/server/throughput.py | 33 ++++++-- src/petals/utils/convert_8bit.py | 39 --------- src/petals/utils/convert_block.py | 132 ++++++++++++++++++++++++++++++ tests/test_aux_functions.py | 11 ++- tests/test_block_exact_match.py | 2 +- tests/test_remote_sequential.py | 21 ++--- tests/test_tensor_parallel.py | 46 +++++++++++ 16 files changed, 450 insertions(+), 202 deletions(-) delete mode 100644 src/petals/utils/convert_8bit.py create mode 100644 src/petals/utils/convert_block.py create mode 100644 tests/test_tensor_parallel.py diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index aa8b114..af6299b 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -86,16 +86,16 @@ jobs: sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \ + python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \ --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log & SERVER3_PID=$! - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \ + python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \ --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log & SERVER4_PID=$! python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server5.log & + --initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu --torch_dtype float32 &> server5.log & SERVER5_PID=$! tail -n 100 -f server*.log & diff --git a/setup.cfg b/setup.cfg index 15cdfe0..effa114 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ install_requires = protobuf>=3.20.3,<4.0dev speedtest-cli==2.1.3 hivemind==1.1.3 + tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 79c1b9d..e089937 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -129,8 +129,12 @@ def main(): parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained") parser.add_argument('--load_in_8bit', type=str, default=None, - help="Convert the loaded model into mixed-8bit quantized model. " + help="Convert the loaded transformer blocks into mixed-8bit quantized model. " "Default: True if GPU is available. Use `--load_in_8bit False` to disable this") + parser.add_argument("--tensor_parallel_devices", nargs='+', default=None, + help= + "Split each block between the specified GPUs such that each device holds a portion of every " + "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism") parser.add_argument("--skip_reachability_check", action='store_true', help="Skip checking this server's reachability via health.petals.ml " diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 919c8c1..d5a7181 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -1,9 +1,14 @@ +from __future__ import annotations + +import dataclasses from dataclasses import dataclass from enum import Enum -from typing import Any, Dict +from typing import Any, Dict, Tuple from hivemind import PeerID +from petals.server.memory_cache import Handle + ModuleUID = str UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention" CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4" @@ -39,3 +44,9 @@ class RemoteSpanInfo: RPCInfo = Dict[str, Any] + + +@dataclasses.dataclass(frozen=True) +class InferenceMetadata: + prefix_length: int + cache_handles: Tuple[Handle, ...] diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 9aa4ea5..67b03c0 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -1,12 +1,19 @@ """Code for serving bloom blocks via hivemind-server""" +from __future__ import annotations + +from itertools import chain from typing import Any, Dict, Sequence, Tuple import torch -from hivemind import BatchTensorDescriptor +from hivemind import BatchTensorDescriptor, TensorDescriptor from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger +from tensor_parallel import TensorParallel +from tensor_parallel.tensor_parallel import PerDeviceTensors +from transformers import BloomConfig +from transformers.models.bloom.modeling_bloom import BloomAttention -from petals.bloom.block import WrappedBloomBlock +from petals.data_structures import InferenceMetadata from petals.server.memory_cache import MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import is_dummy @@ -17,9 +24,10 @@ logger = get_logger(__file__) class TransformerBackend(ModuleBackend): """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference""" - def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs): + def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs): super().__init__(*args, **kwargs) - assert isinstance(self.module, WrappedBloomBlock) + assert isinstance(self.module, TensorParallel) + self.config = config self.memory_cache = memory_cache for name, param in self.module.named_parameters(): assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does" @@ -27,18 +35,26 @@ class TransformerBackend(ModuleBackend): assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does" max_batch_size = self.forward_pool.max_batch_size + device = self.module.devices[self.module.output_device_index] self.inference_pool = PrioritizedTaskPool( - self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference" + self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference" ) self.forward_pool = PrioritizedTaskPool( - self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward" + self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward" ) self.backward_pool = PrioritizedTaskPool( - self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward" + self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward" ) assert backend_dtype is not None self.dtype = backend_dtype + self.shard_num_heads = [] + for shard in self.module.module_shards: + for submodule in shard.modules(): + if isinstance(submodule, BloomAttention): + self.shard_num_heads.append(submodule.num_heads) + assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head + self.inference_schema = ( ( *self.args_schema, @@ -48,44 +64,60 @@ class TransformerBackend(ModuleBackend): self.kwargs_schema, ) + def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]: + """Create tensor descriptors for attention cache tensors used during inference_step""" + head_dim = self.config.hidden_size // self.config.n_head + cache_tensors = [] + for device, num_heads in zip(self.module.devices, self.shard_num_heads): + keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device) + values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device) + cache_tensors.extend((keys, values)) + return cache_tensors + def inference_step( - self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor + self, + hidden_states: torch.Tensor, + hypo_ids: torch.LongTensor, + inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: - num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim with torch.inference_mode(): assert ( hidden_states.ndim == 3 ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" - cache_handle, rel_index, prefix_length = map(int, cache_metadata[0]) - - with self.memory_cache.use_cache(cache_handle) as cache: - batch_size = cache.shape[2] - max_length = cache.shape[-1] // (head_dim * num_heads) - assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4 - if not is_dummy(hypo_ids): - assert hypo_ids.shape[0] == batch_size - cache[rel_index, :, :] = cache[rel_index, :, hypo_ids] # in-place reorder cache by hypo ids - key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length) - value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim) - - key_past = key_cache.flatten(0, 1)[:, :, :prefix_length] # [batch * num_heads, head_dim, kv_length] - value_past = value_cache.flatten(0, 1)[:, :prefix_length, :] # [batch * num_heads, kv_length, head_dim] - logger.debug( - f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}" - ) - hidden_states, (new_key, new_value) = self.module.forward( - hidden_states, layer_past=(key_past, value_past), use_cache=True - ) - new_length = new_key.shape[-1] - assert new_length > prefix_length - assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0] - assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length - new_key = new_key.view(batch_size, num_heads, head_dim, -1) - new_value = new_value.view(batch_size, num_heads, -1, head_dim) - key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length] - value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :] + with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors: + self._reorder_cache_inplace(cache_tensors, hypo_ids) + layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) + hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True) + self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length) return (hidden_states,) + def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor): + """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids""" + if not is_dummy(hypo_ids): + for cache_tensor in cache_tensors: + cache_tensor[...] = cache_tensor[hypo_ids] # in-place reorder cache by hypo ids + + def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]: + """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past""" + key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2]) + for i in range(len(key_cache)): + key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length] # [batch * num_heads, head_dim, kv_length] + value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length] # [batch * num_heads, kv_length, head_dim] + layer_past = tuple(chain(*zip(key_cache, value_cache))) + return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past + + def _update_cache_inplace( + self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int + ): + """Writes new key/value tensors back into cache, works in-place""" + _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape + for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]): + new_key = new_key.view(*cache_key.shape[:3], new_length) + cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length] + for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]): + new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim) + cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :] + def get_pools(self) -> Sequence[PrioritizedTaskPool]: return self.forward_pool, self.backward_pool, self.inference_pool diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index ff66e4b..387431a 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import asyncio import contextlib +from itertools import chain from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch @@ -8,10 +11,10 @@ from hivemind import ( DHT, MSGPackSerializer, P2PContext, - TensorDescriptor, deserialize_tensor_stream, deserialize_torch_tensor, nested_flatten, + nested_pack, serialize_torch_tensor, ) from hivemind.moe.server.connection_handler import ConnectionHandler @@ -21,8 +24,9 @@ from hivemind.utils.asyncio import amap_in_executor, anext from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming -from petals.data_structures import CHAIN_DELIMITER, ModuleUID +from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID from petals.server.backend import TransformerBackend +from petals.server.memory_cache import Handle from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase from petals.utils.misc import DUMMY, is_dummy @@ -122,17 +126,12 @@ class TransformerConnectionHandler(ConnectionHandler): point_per_piece = points / max_length if max_length > 0 else 0.0 batch_size = request.tensors[0].size[0] if request.tensors else 1 - - cache_metadata = torch.tensor( - [[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64 - ) # [cache_handle, rel_index, prefix_length] prefix_length = 0 - async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle: + async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles: + assert len(cache_handles) == len(requested_backends) while request.tensors: # iterate while user is willing to supply tensors - hidden_states, prompts, hypo_ids = [ - deserialize_torch_tensor(tensor) for tensor in request.tensors - ] + hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) # Cast inputs to backend dtype hidden_states = hidden_states.to(requested_backends[0].dtype) @@ -155,16 +154,14 @@ class TransformerConnectionHandler(ConnectionHandler): ) # run request tensors through all requested modules, update caches - for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)): + for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts): if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt if hidden_states.numel() == 0: continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. # when user wants to pre-allocate cache or check that server *can* allocate that cache - cache_metadata[:] = torch.tensor( - [cache_handle, rel_index, prefix_length], dtype=torch.int64 - ) + metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles)) assert isinstance( hidden_states, torch.Tensor ), f"hidden states must be tensor, got {type(hidden_states)}" @@ -175,7 +172,6 @@ class TransformerConnectionHandler(ConnectionHandler): backend.inference_pool, PrioritizedTaskPool ), "petals support only prioritized pools" priority = self._prioritizer.prioritize( - cache_metadata, hidden_states, hypo_ids, points=point_per_piece / len(requested_backends), @@ -183,7 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler): type="inference", ) (hidden_states,) = await backend.inference_pool.submit_task( - hidden_states, hypo_ids, cache_metadata, priority=priority + hidden_states, hypo_ids, metadata, priority=priority ) # serialize and send last layer outputs @@ -355,28 +351,14 @@ class TransformerConnectionHandler(ConnectionHandler): @contextlib.asynccontextmanager async def _allocate_cache( self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int - ) -> Sequence[int]: - """Allocate memory cache for all transformer blocks, return cache handle""" - - n_blocks = len(backends) - backend = backends[0] - n_heads = backend.module.self_attention.num_heads - head_dim = backend.module.self_attention.head_dim - descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype) - alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8 - - gib = 1024**3 - cur_size = backend.memory_cache.current_size_bytes - max_size = backend.memory_cache.max_size_bytes - friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf" - logger.info( - f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), " - f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)" - ) - - async with backend.memory_cache.allocate_cache(descr) as handle: - logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)") - yield handle + ) -> Sequence[Sequence[Handle, ...]]: + """ + Allocate memory cache for all transformer blocks, return cache handle + :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend + """ + descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends] + async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles: + yield nested_pack(handles, descriptors) def _log_request( self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py index 53c1a7d..0e39cf5 100644 --- a/src/petals/server/memory_cache.py +++ b/src/petals/server/memory_cache.py @@ -10,7 +10,7 @@ import ctypes import multiprocessing as mp import os import time -from typing import AsyncContextManager, Dict, Optional, Union +from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple import hivemind import torch @@ -26,10 +26,9 @@ Handle = int class MemoryCache: """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs""" - def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float): + def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float): self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1) self.alloc_timeout = alloc_timeout - self.device = device self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event() self._current_size = mp.Value(ctypes.c_int64, 0, lock=False) self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False) @@ -57,26 +56,48 @@ class MemoryCache: self._handle_counter.value = value @contextlib.asynccontextmanager - async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]: + async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]: """ Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed. - :param descr: allocate a tensor of this size, dtype, etc + :param descriptors: one or more tensors tensor of this size, dtype, etc + + :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices; + if not, it will count maximum tensor allocation across devices for the purposes of size limit :note: This function should be called by connection handlers, it can be called concurrently from multiple processes. Furthermore, it can be called concurrently with at most one use_cache call in runtime. """ assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime" - assert descr.device is None and descr - - alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8 - alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr)) + assert all(descr.device is not None for descr in descriptors), "please specify allocated devices" + max_alloc_size = self.get_allocation_size(*descriptors) + + gib = 1024**3 + cur_size, max_size = self.current_size_bytes, self.max_size_bytes + friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf" + logger.info( + f"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), " + f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)" + ) + + alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors)) try: - yield await shield_and_wait(alloc_task) + handles = await shield_and_wait(alloc_task) + logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)") + yield handles finally: - await shield_and_wait(self._schedule_free(alloc_size, alloc_task)) - - async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle: + await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task)) + + @staticmethod + def get_allocation_size(*descriptors: TensorDescriptor) -> int: + """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum""" + alloc_size_by_device = {} + for descr in descriptors: + tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8 + alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size + return max(alloc_size_by_device.values()) + + async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]: """ This method should be called inside asyncio.shield() because: - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation @@ -87,11 +108,11 @@ class MemoryCache: if self.current_size_bytes + alloc_size > self.max_size_bytes: await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout) async with hivemind.utils.enter_asynchronously(self._lock_metadata): - handle = int(self.handle_counter) + handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors))) self.current_size_bytes += alloc_size - self.handle_counter += 1 # note: this will eventually overflow and it is okay - self._pipe_send.send((handle, descr)) - return handle + self.handle_counter += len(handles) # note: this will eventually overflow and it is okay + self._pipe_send.send((handles, descriptors)) + return handles async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task): """ @@ -102,10 +123,10 @@ class MemoryCache: if alloc_task.exception() is not None: return - handle = alloc_task.result() + handles = alloc_task.result() async with hivemind.utils.enter_asynchronously(self._lock_metadata): - self._pipe_send.send((handle, None)) # signal runtime to free that handle + self._pipe_send.send((handles, None)) # signal runtime to free these handles self.current_size_bytes -= alloc_size self._memory_freed_event.set() @@ -125,11 +146,11 @@ class MemoryCache: self._memory_freed_event.clear() @contextlib.contextmanager - def use_cache(self, handle: Handle) -> torch.Tensor: + def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]: """ - Return a tensor that was previously allocated with try_allocate_cache, + Return one or more tensors previously allocated with allocate_cache, - :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism. + :note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism. However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache """ assert os.getpid() == self.runtime_pid @@ -138,20 +159,20 @@ class MemoryCache: with self._lock_metadata: # read creation/deletion requests from connection handlers while self._pipe_recv.poll(): - recv_handle, recv_data = self._pipe_recv.recv() - if isinstance(recv_data, TensorDescriptor): - self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device) - elif recv_data is None: - if recv_handle not in self._allocated_tensors: - logger.warning( - f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle" - ) - self._allocated_tensors.pop(recv_handle, None) - else: - logger.error(f"MemoryCache pipe received unexpected message: {recv_data}") - - assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})" - yield self._allocated_tensors[handle] + recv_handles, recv_data = self._pipe_recv.recv() + if recv_data is not None: # create new tensors + assert len(recv_handles) == len(recv_data) + for handle, descr in zip(recv_handles, recv_data): + self._allocated_tensors[handle] = descr.make_zeros() + assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})" + else: # delete tensors by handle + for handle in recv_handles: + if handle not in self._allocated_tensors: + logger.warning( + f"Sanity check failed: asked to delete handle {handle}, but there is no such handle" + ) + self._allocated_tensors.pop(handle, None) + yield tuple(self._allocated_tensors[handle] for handle in handles) class AllocationFailed(Exception): diff --git a/src/petals/server/server.py b/src/petals/server/server.py index f7006cc..a8927aa 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -6,7 +6,7 @@ import multiprocessing as mp import random import threading import time -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Sequence, Union import numpy as np import psutil @@ -29,7 +29,7 @@ from petals.server.block_utils import get_block_size from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.throughput import get_host_throughput -from petals.utils.convert_8bit import replace_8bit_linear +from petals.utils.convert_block import check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR logger = get_logger(__file__) @@ -76,6 +76,7 @@ class Server: mean_block_selection_delay: float = 2.5, use_auth_token: Optional[str] = None, load_in_8bit: Optional[bool] = None, + tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, **kwargs, ): @@ -128,6 +129,8 @@ class Server: if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) + if device.type == "cuda" and device.index is None: + device = torch.device(device.type, index=0) self.device = device if isinstance(torch_dtype, str): @@ -141,6 +144,13 @@ class Server: logger.info("Model weights will be loaded in 8-bit format") self.load_in_8bit = load_in_8bit + if tensor_parallel_devices is None: + tensor_parallel_devices = (device,) + self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices)) + if len(self.tensor_parallel_devices) > 1: + logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}") + check_device_balance(self.tensor_parallel_devices) + assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" if num_blocks is None and block_indices is None: num_blocks = self._choose_num_blocks() @@ -174,6 +184,7 @@ class Server: device, torch_dtype, load_in_8bit=load_in_8bit, + tensor_parallel_devices=self.tensor_parallel_devices, force_eval=(throughput == "eval"), cache_dir=cache_dir, ) @@ -214,13 +225,28 @@ class Server: self.converted_model_name_or_path == "bigscience/bloom-petals" ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually" assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually" + num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1 + + if num_devices > 1: + memory_per_device = tuple( + torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices + ) + total_memory = min(memory_per_device) * num_devices + if max(memory_per_device) / min(memory_per_device) > 1.5: + raise ValueError( + "GPU devices have highly uneven memory, which makes tensor parallelism inefficient. " + "Please launch individual servers on each GPU or set --num_blocks manually to " + "override this exception." + ) + else: + total_memory = torch.cuda.get_device_properties(self.device).total_memory - total_memory = torch.cuda.get_device_properties(self.device).total_memory block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit) gib = 1024**3 - attn_cache_per_block = 0.5 * gib # TODO: This does not account for manually set --attn_cache_size + attn_cache_per_block = 0.5 * gib * num_devices # TODO: This does not account for manually set --attn_cache_size - num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block)) + autograd_memory = 2 * gib * num_devices # gpu memory used for intermediate tensors in rpc_backward + num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block)) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" logger.info( @@ -260,6 +286,7 @@ class Server: sender_threads=self.sender_threads, use_auth_token=self.use_auth_token, load_in_8bit=self.load_in_8bit, + tensor_parallel_devices=self.tensor_parallel_devices, start=True, ) try: @@ -352,6 +379,7 @@ class ModuleContainer(threading.Thread): expiration: Optional[float], use_auth_token: Optional[str], load_in_8bit: bool, + tensor_parallel_devices: Sequence[torch.device], **kwargs, ) -> ModuleContainer: module_uids = [f"{prefix}.{block_index}" for block_index in block_indices] @@ -367,7 +395,9 @@ class ModuleContainer(threading.Thread): joining_announcer.start() logger.info(f"Announced that blocks {block_indices} are joining") - memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout) + assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices) + + memory_cache = MemoryCache(attn_cache_size, alloc_timeout) blocks = {} try: for module_uid, block_index in zip(module_uids, block_indices): @@ -380,18 +410,13 @@ class ModuleContainer(threading.Thread): cache_dir=cache_dir, max_disk_space=max_disk_space, ) + block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True) - if load_in_8bit: - block = replace_8bit_linear(block) - - block = block.to(device) - for param in block.parameters(): - param.requires_grad = False - - backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype + backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype blocks[module_uid] = TransformerBackend( module_uid, block, + config=block_config, memory_cache=memory_cache, backend_dtype=backend_dtype, args_schema=( @@ -451,6 +476,7 @@ class ModuleContainer(threading.Thread): request_timeout: float, session_timeout: float, step_timeout: float, + device: Union[str, torch.device], start: bool, **kwargs, ): @@ -469,7 +495,8 @@ class ModuleContainer(threading.Thread): ) for _ in range(num_handlers) ] - self.runtime = Runtime(self.module_backends, **kwargs) + self.runtime = Runtime(self.module_backends, device=None, **kwargs) + # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed. self.online_announcer = ModuleAnnouncerThread( list(self.module_backends.keys()), dht, diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py index 1374f94..330679c 100644 --- a/src/petals/server/task_pool.py +++ b/src/petals/server/task_pool.py @@ -5,7 +5,7 @@ import time from concurrent.futures._base import PENDING from dataclasses import dataclass, field from queue import PriorityQueue -from typing import Any, List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from hivemind import get_logger @@ -43,6 +43,7 @@ class PrioritizedTaskPool(TaskPoolBase): :param name: pool name, used for logging :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more + :param device: if specified, input tensors will be moved to that device by default :param start: if True, start automatically at the end of __init__ """ @@ -52,11 +53,13 @@ class PrioritizedTaskPool(TaskPoolBase): max_batch_size: int, name: str, min_batch_size=1, + device: Optional[torch.device] = None, daemon=True, start=False, ): super().__init__(process_func, daemon=daemon, name=name) self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size + self.device = device self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime @@ -101,7 +104,7 @@ class PrioritizedTaskPool(TaskPoolBase): logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM") self.terminate() - def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture: + def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture: """Add task to this pool's queue, return Future for its output""" future = MPFuture() # Remove shmem from MPFuture. This disables the .cancel() feature but @@ -129,10 +132,9 @@ class PrioritizedTaskPool(TaskPoolBase): self, timeout: Optional[float] = None, device: Optional[torch.device] = None ) -> Tuple[Any, List[torch.Tensor]]: """receive next batch of arrays""" + device = device if device is not None else self.device task = self._ordered_tasks.get(block=True, timeout=timeout) - batch_inputs = [ - tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args - ] + batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args] self._dispatched_tasks[task.uid] = task self.batch_receiver.recv() # reduce the number of active batches if not self._ordered_tasks.empty(): @@ -142,11 +144,7 @@ class PrioritizedTaskPool(TaskPoolBase): def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]): """send results for a processed batch, previously loaded through load_batch_to_runtime""" - batch_outputs = [ - tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad) - for tensor in batch_outputs - ] - + batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs] task = self._dispatched_tasks.pop(uid, None) if task is None: logger.error( @@ -182,3 +180,13 @@ class PrioritizedTaskPool(TaskPoolBase): assert len(item) == 2 self._priority.value = float(item[0]) self._oldest_undispatched_timestamp.value = float(item[1]) + + +def _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str], share_memory: bool = False): + if isinstance(arg, torch.Tensor): + arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad) + # note: it is important that non_blocking is disabled if share_memory=True; using share_memory on a tensor + # produced by a non-blocking copy will result in undefined behavior (depending on your gpu speed) + if share_memory: + arg = arg.share_memory_() + return arg diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index c24e710..73ad973 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -2,9 +2,10 @@ import fcntl import json import os import time +from collections import Counter from hashlib import sha256 from pathlib import Path -from typing import Optional, Union +from typing import Optional, Sequence, Union import torch from hivemind.utils.logging import get_logger @@ -12,7 +13,7 @@ from transformers import BloomConfig from petals.bloom.block import WrappedBloomBlock from petals.server.block_utils import resolve_block_dtype -from petals.utils.convert_8bit import replace_8bit_linear +from petals.utils.convert_block import convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR logger = get_logger(__file__) @@ -37,6 +38,7 @@ def get_host_throughput( dtype: Union[str, torch.dtype], *, load_in_8bit: bool, + tensor_parallel_devices: Sequence[torch.device], force_eval: bool = False, cache_dir: Optional[str] = None, ) -> float: @@ -57,6 +59,9 @@ def get_host_throughput( cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}" cache_key += f"_device_{get_device_name(device).replace(' ', '_')}" cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}" + if len(tensor_parallel_devices) > 1: + for i, device_i in enumerate(tensor_parallel_devices): + cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}" cache = {} try: @@ -69,7 +74,9 @@ def get_host_throughput( cache = {} if cache_key not in cache: - cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit) + cache[cache_key] = measure_throughput_info( + config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices + ) try: os.makedirs(cache_path.parent, exist_ok=True) @@ -87,6 +94,7 @@ def measure_throughput_info( dtype: torch.dtype, *, load_in_8bit: bool, + tensor_parallel_devices: Sequence[torch.device], ) -> float: """Measure network and compute throughput in forward pass tokens per second""" @@ -95,7 +103,9 @@ def measure_throughput_info( ) return min( measure_network_rps(config), - measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit), + measure_compute_rps( + config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices + ), ) @@ -129,14 +139,15 @@ def measure_compute_rps( dtype: torch.dtype, *, load_in_8bit: bool, + tensor_parallel_devices: Sequence[torch.device], n_tokens: int = 16, n_steps: int = 500, ) -> float: + if not tensor_parallel_devices: + tensor_parallel_devices = (device,) with torch.inference_mode(): block = WrappedBloomBlock(config).to(dtype) - if load_in_8bit: - block = replace_8bit_linear(block) - block = block.to(device) + block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True) cache = None elapsed = 0 @@ -149,9 +160,13 @@ def measure_compute_rps( elapsed += time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed + devices_repr = get_device_name(device) + if len(tensor_parallel_devices) > 1: + device_names = tuple(map(get_device_name, map(torch.device, tensor_parallel_devices))) + devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common()) + logger.info( - f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): " - f"{device_rps:.1f} RPS" + f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS" ) return device_rps diff --git a/src/petals/utils/convert_8bit.py b/src/petals/utils/convert_8bit.py deleted file mode 100644 index eeb29e7..0000000 --- a/src/petals/utils/convert_8bit.py +++ /dev/null @@ -1,39 +0,0 @@ -import bitsandbytes as bnb -import torch - -from petals.utils.linear8bitlt_patch import CustomLinear8bitLt - - -def replace_8bit_linear(model, threshold=6.0): - """ - A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes` - library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): - 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA - version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ - bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116) - The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should - be kept as a `torch.nn.Linear` module. - Parameters: - model (`torch.nn.Module`): - Input model or `torch.nn.Module` as the function is run recursively. - threshold (`float`, *optional*): - `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to - `6.0` as described by the paper. - """ - for n, module in model.named_children(): - if len(list(module.children())) > 0: - replace_8bit_linear(module, threshold) - - if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]: - model._modules[n] = CustomLinear8bitLt( - module.in_features, - module.out_features, - module.bias is not None, - has_fp16_weights=False, - threshold=threshold, - ) - model._modules[n].weight = bnb.nn.Int8Params( - module.weight.data, requires_grad=False, has_fp16_weights=False - ).to(module.weight.dtype) - model._modules[n].bias = module.bias - return model diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py new file mode 100644 index 0000000..794ecd9 --- /dev/null +++ b/src/petals/utils/convert_block.py @@ -0,0 +1,132 @@ +""" +Tools for converting transformer blocks, applying quantization and/or tensor parallelism +""" +import re +from typing import Sequence + +import bitsandbytes as bnb +import tensor_parallel as tp +import torch +import torch.nn as nn +from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from tensor_parallel.slicing_configs import get_bloom_config +from transformers import BloomConfig +from transformers.models.bloom.modeling_bloom import BloomAttention + +from petals.bloom.block import WrappedBloomBlock +from petals.utils.linear8bitlt_patch import CustomLinear8bitLt + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__file__) + + +def convert_block( + block: WrappedBloomBlock, + config: BloomConfig, + tensor_parallel_devices: Sequence[torch.device], + output_device: torch.device, + load_in_8bit: bool, + threshold: float = 6.0, + freeze: bool = True, +) -> tp.TensorParallel: + """ + Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization + + :note: some optimizations will modify the input block in-place! + :param block: a single transformer block, either pre-trained or newly initialized + :param config: HF transformers config for the full model + :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices + :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity) + :param output_device: if tensor_parallel_devices is True, output + :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint + :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 ) + :param freeze: if True (default), make all module parameters non-trainable + :return: a module that acts like the original block, but runs with all specified optimizations + + """ + if freeze: + for param in block.parameters(): + param.requires_grad = False + + block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device) + + if load_in_8bit: + block = replace_8bit_linear(block, threshold=threshold) + + for shard, device in zip(block.module_shards, block.devices): + shard.to(device) + + return block + + +def replace_8bit_linear(model: nn.Module, threshold=6.0): + """ + A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes` + library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): + 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA + version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ + bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116) + The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should + be kept as a `torch.nn.Linear` module. + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + threshold (`float`, *optional*): + `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to + `6.0` as described by the paper. + """ + for n, module in model.named_children(): + if len(list(module.children())) > 0: + replace_8bit_linear(module, threshold) + + if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]: + assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}" + model._modules[n] = CustomLinear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=threshold, + ) + model._modules[n].weight = bnb.nn.Int8Params( + module.weight.data, requires_grad=False, has_fp16_weights=False + ).to(module.weight.dtype) + model._modules[n].bias = module.bias + return model + + +def make_tensor_parallel( + block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device +): + assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt)) + tp_config = get_bloom_config(model_config, devices) + del tp_config.state_rules[re.compile(".*word_embeddings.weight$")] + tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True) + total_heads = 0 + for tp_shard in tp_block.module_shards: + for submodule in tp_shard.modules(): + if isinstance(submodule, BloomAttention): + total_heads += submodule.num_heads + assert total_heads == model_config.n_head + return tp_block + + +def check_device_balance(devices: Sequence[torch.device]): + if not all(device.type == "cuda" for device in devices): + logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk") + return + unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices)) + if len(unique_device_capabilities) > 1: + logger.warning( + f"Found GPUs with uneven capabilities: {unique_device_capabilities}. " + f"Using GPUs with different performance will cause the server to wait for the slowest GPU." + ) + + memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices) + used_memory = min(memory_per_device) * len(memory_per_device) + wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device) + if wasted_memory_rate > 0.05: + logger.warning( + f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. " + f"Consider running high-memory GPUs in a separate server." + ) diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index 46c4bfe..554127f 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -7,10 +7,17 @@ from petals.server.throughput import measure_compute_rps, measure_network_rps @pytest.mark.forked -def test_throughput_basic(): +@pytest.mark.parametrize("tensor_parallel", [False, True]) +def test_throughput_basic(tensor_parallel: bool): config = DistributedBloomConfig.from_pretrained(MODEL_NAME) + tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else () compute_rps = measure_compute_rps( - config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10 + config, + device=torch.device("cpu"), + dtype=torch.bfloat16, + load_in_8bit=False, + tensor_parallel_devices=tensor_parallel_devices, + n_steps=10, ) assert isinstance(compute_rps, float) and compute_rps > 0 network_rps = measure_network_rps(config) diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index ab41ce8..664f255 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -13,7 +13,7 @@ from petals.dht_utils import get_remote_module @pytest.mark.forked -def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3): +def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) config = DistributedBloomConfig.from_pretrained(MODEL_NAME) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index ed76696..a8e585f 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -1,6 +1,7 @@ import pytest import torch -from hivemind import DHT, BatchTensorDescriptor, get_logger +import torch.nn.functional as F +from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler from hivemind.proto import runtime_pb2 from test_utils import * @@ -39,10 +40,10 @@ def test_remote_sequential(): assert hidden.shape == test_inputs.shape assert hidden.requires_grad second_half_outputs = second_half(hidden) - assert torch.allclose(second_half_outputs, full_outputs) + assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4) (second_half_outputs * grad_proj).sum().backward() - assert torch.allclose(test_inputs.grad, full_grad) + assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3) # test RemoteSequential with lossy compression block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] @@ -58,7 +59,7 @@ def test_remote_sequential(): assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used" assert abs(approx_outputs - full_outputs).mean() < 0.01 absmax = abs(full_grad).max() - assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01 + assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05 class DummyCustomSequenceManager(RemoteSequenceManager): @@ -87,9 +88,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) remote_sequential = RemoteSequential(config, dht) - inputs = torch.randn(batch_size, seq_len, config.hidden_size) - output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size) - input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True) + inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1) + output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1) + input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1) intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True) input_prompts = input_prompts.detach().requires_grad_(True) @@ -117,10 +118,10 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32) (outputs_ref,) = block(outputs_ref) - assert torch.allclose(outputs_ref, outputs) + assert torch.allclose(outputs_ref, outputs, atol=1e-3) (outputs_ref * output_proj).sum().backward() assert input_prompts_ref.grad is not None - assert torch.allclose(input_prompts_ref.grad, input_prompts.grad) + assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2) assert intermediate_prompts_ref.grad is not None - assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad) + assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2) diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py new file mode 100644 index 0000000..40eb1ee --- /dev/null +++ b/tests/test_tensor_parallel.py @@ -0,0 +1,46 @@ +import random + +import pytest +import torch +import transformers +from tensor_parallel import TensorParallel +from tensor_parallel.slicing_configs import get_bloom_config +from test_utils import MODEL_NAME + +from petals.bloom.from_pretrained import load_pretrained_block + + +@pytest.mark.forked +@pytest.mark.parametrize("custom_config", [True, False]) +@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4]) +def test_tp_block(devices, custom_config): + block_index = random.randint(0, 10) + model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME) + block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0]) + + tp_config = None + if custom_config: + tp_config = get_bloom_config(model_config, devices) + + batch_size = 2 + prefix_length = 5 + + test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0]) + test_inputs2 = test_inputs1.detach().clone().requires_grad_(True) + test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0]) + test_prefix2 = test_prefix1.detach().clone().requires_grad_(True) + grad_proj = torch.rand_like(test_inputs1) + + y_prefix_ref, layer_past = block(test_prefix1, use_cache=True) + y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past) + y_ref.backward(grad_proj) + + block_tp = TensorParallel(block, devices, config=tp_config) + y_prefix, layer_past = block_tp(test_prefix2, use_cache=True) + y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past) + y_ours.backward(grad_proj) + + assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6) + assert torch.allclose(y_ours, y_ref, atol=1e-6) + assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4) + assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4) From 356e099c3de25d672983c646ea39c688c915f059 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 4 Jan 2023 02:18:45 +0400 Subject: [PATCH 008/168] Make Docker command more visible (#175) --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 05f1ce2..f57a001 100644 --- a/README.md +++ b/README.md @@ -31,15 +31,17 @@ for input_ids, labels in data_loader: 🚀  Try now in Colab

-Connect your own GPU and increase Petals capacity: +Connect your own GPU and increase Petals capacity — run this in an [Anaconda](https://www.anaconda.com) env: ```bash -# In an Anaconda env conda install pytorch cudatoolkit=11.3 -c pytorch pip install -U petals python -m petals.cli.run_server bigscience/bloom-petals +``` + +Or use our [Docker](https://www.docker.com) image: -# Or using our GPU-enabled Docker image +```bash sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --rm \ learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals ``` From 6948a0c5ee9976ac1165a5f8909056a17bf03019 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 4 Jan 2023 06:42:03 +0400 Subject: [PATCH 009/168] Allow to disable chunked forward (#176) --- src/petals/bloom/modeling_utils.py | 7 +++++-- src/petals/client/remote_model.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/bloom/modeling_utils.py index 6847423..366c035 100644 --- a/src/petals/bloom/modeling_utils.py +++ b/src/petals/bloom/modeling_utils.py @@ -45,8 +45,11 @@ class LMHead(nn.Module): def forward(self, hidden_states): word_embeddings = self.word_embeddings.weight - # We use 'chunked_forward' only when embeddings are in half-precision on CPU. - if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu": + if ( + self.chunk_size is not None + and word_embeddings.dtype in [torch.float16, torch.bfloat16] + and word_embeddings.device.type == "cpu" + ): lm_logits = self.chunked_forward(hidden_states) else: # Switch dtype in case word_embeddings are fp16/bf16 diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 52feb22..c6c46ee 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -34,7 +34,8 @@ class DistributedBloomConfig(BloomConfig): dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) daemon_startup_timeout: int = 30 dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models - chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU + chunk_size_for_efficient_fp16_on_cpu: Optional[int] = 10000 + # Chunk size for efficient half-precision on CPU in the LM head. Set to None if your CPU works fast with bfloat16. pre_seq_len: int = 0 # a number of tokens for prompt tuning. tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters'] request_timeout: int = 30 # a number of seconds for waiting result from each node From 55698381d0cf0c3b3f0ab33834e183eea19bb1ba Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 4 Jan 2023 23:28:16 +0400 Subject: [PATCH 010/168] Disable chunked_forward() on AVX512 CPUs (#179) --- setup.cfg | 1 + src/petals/bloom/modeling_utils.py | 33 ++++++++++++++++++++++-------- src/petals/client/remote_model.py | 12 +++++++---- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/setup.cfg b/setup.cfg index effa114..8c7b19a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ install_requires = tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 + cpufeature>=0.2.0 [options.extras_require] dev = diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/bloom/modeling_utils.py index 366c035..4e2899c 100644 --- a/src/petals/bloom/modeling_utils.py +++ b/src/petals/bloom/modeling_utils.py @@ -4,9 +4,11 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e See commit history for authorship. """ +import psutil import torch import torch.nn.functional as F import torch.utils.checkpoint +from cpufeature import CPUFeature from hivemind import get_logger from torch import nn from transformers import BloomConfig @@ -24,7 +26,14 @@ class LMHead(nn.Module): def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding): super().__init__() self.word_embeddings = word_embeddings - self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu + + self.use_chunked_forward = config.use_chunked_forward + if self.use_chunked_forward == "auto": + # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward(). + # Otherwise, it's ~8x slower. + self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]) + self.chunked_forward_step = config.chunked_forward_step + self._bf16_warning_shown = False @property def in_features(self) -> int: @@ -46,9 +55,9 @@ class LMHead(nn.Module): word_embeddings = self.word_embeddings.weight if ( - self.chunk_size is not None - and word_embeddings.dtype in [torch.float16, torch.bfloat16] + word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu" + and self.use_chunked_forward ): lm_logits = self.chunked_forward(hidden_states) else: @@ -59,9 +68,17 @@ class LMHead(nn.Module): def chunked_forward(self, hidden_states): """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. - chunk_size: provides trade-off between efficiency and extra memory consumption. + chunked_forward_step: provides trade-off between efficiency and extra memory consumption. """ - assert self.chunk_size > 0, "Chunk size for chunked forward must be positive" + assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive" + + if not self._bf16_warning_shown: + if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: + logger.warning( + "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. " + "Consider loading the model with torch_dtype='float32'" + ) + self._bf16_warning_shown = True word_embeddings = self.word_embeddings.weight num_embeddings = self.word_embeddings.num_embeddings @@ -69,7 +86,7 @@ class LMHead(nn.Module): hidden_states = hidden_states.float() output = torch.empty(*hidden_states.shape[:-1], num_embeddings) - for i in range(0, num_embeddings, self.chunk_size): - chunk = word_embeddings[i : i + self.chunk_size].float() - output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk) + for i in range(0, num_embeddings, self.chunked_forward_step): + chunk = word_embeddings[i : i + self.chunked_forward_step].float() + output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk) return output diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index c6c46ee..3e52e40 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -1,6 +1,6 @@ import os from contextlib import contextmanager -from typing import List, Optional +from typing import List, Optional, Union import hivemind import torch @@ -34,11 +34,15 @@ class DistributedBloomConfig(BloomConfig): dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) daemon_startup_timeout: int = 30 dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models - chunk_size_for_efficient_fp16_on_cpu: Optional[int] = 10000 - # Chunk size for efficient half-precision on CPU in the LM head. Set to None if your CPU works fast with bfloat16. + request_timeout: int = 30 # a number of seconds for waiting result from each node + pre_seq_len: int = 0 # a number of tokens for prompt tuning. tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters'] - request_timeout: int = 30 # a number of seconds for waiting result from each node + + # This settings matter for running the client with dtype bfloat16 on CPU. + # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations. + use_chunked_forward: Union[str, bool] = "auto" + chunked_forward_step: int = 16384 original_register_parameter = nn.Module.register_parameter From e27706358c62f47b3d18bfed908678db05b49ec0 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 5 Jan 2023 09:34:03 +0400 Subject: [PATCH 011/168] Use slightly less memory in .generate() (#177) --- README.md | 2 +- src/petals/client/remote_generation.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index f57a001..e1eeb7f 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache -- Check out more examples and tutorials: -- Chatbot web app: [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat) +- Chatbot web app (connects to Petals via an HTTP endpoint): [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat) - Training a personified chatbot: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) - Fine-tuning BLOOM for text semantic classification: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) - Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index 053e209..4182ca8 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -40,7 +40,7 @@ class RemoteGenerationMixin: return self.transformer.h.inference_session(**kwargs) - @torch.no_grad() + @torch.inference_mode() def generate( self, inputs: Optional[torch.Tensor] = None, @@ -171,13 +171,15 @@ class RemoteGenerationMixin: seq_idx = outputs[0].size(1) hypo_ids = torch.arange(outputs[0].size(0)) while True: - embs = self.transformer.word_embeddings(outputs[-1]) + hidden_state = self.transformer.word_embeddings(outputs[-1]) intermediate_prompts = None if self.config.pre_seq_len > 0 and len(outputs) == 1: - prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0)) - embs = torch.cat([prompts, embs], dim=1) - embs = self.transformer.word_embeddings_layernorm(embs) - hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1] + prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0)) + hidden_state = torch.cat([prompts, hidden_state], dim=1) + hidden_state = self.transformer.word_embeddings_layernorm(hidden_state) + + hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1] + hidden_state = self.transformer.ln_f(hidden_state) lm_logits = self.lm_head(hidden_state) From 6dd9a938bd38f821f1e2a01b836a58c9a9a774f3 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 5 Jan 2023 10:34:52 +0400 Subject: [PATCH 012/168] Import bitsandbytes only if it's going to be used (#180) --- src/petals/utils/convert_block.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 794ecd9..0afe641 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -4,7 +4,6 @@ Tools for converting transformer blocks, applying quantization and/or tensor par import re from typing import Sequence -import bitsandbytes as bnb import tensor_parallel as tp import torch import torch.nn as nn @@ -14,7 +13,6 @@ from transformers import BloomConfig from transformers.models.bloom.modeling_bloom import BloomAttention from petals.bloom.block import WrappedBloomBlock -from petals.utils.linear8bitlt_patch import CustomLinear8bitLt use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) @@ -75,6 +73,12 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0): `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to `6.0` as described by the paper. """ + + # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes + import bitsandbytes as bnb + + from petals.utils.linear8bitlt_patch import CustomLinear8bitLt + for n, module in model.named_children(): if len(list(module.children())) > 0: replace_8bit_linear(module, threshold) @@ -98,7 +102,6 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0): def make_tensor_parallel( block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device ): - assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt)) tp_config = get_bloom_config(model_config, devices) del tp_config.state_rules[re.compile(".*word_embeddings.weight$")] tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True) From d1fa5eb26044ea471da9ef9ee0d2e6b33e534762 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 6 Jan 2023 04:57:49 +0300 Subject: [PATCH 013/168] hotfix: add initial peer that did not crash :) (#181) add hotfix initial peer (@borzunov's peers are down) --- src/petals/constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/petals/constants.py b/src/petals/constants.py index a4620c3..27c8c70 100644 --- a/src/petals/constants.py +++ b/src/petals/constants.py @@ -3,4 +3,5 @@ PUBLIC_INITIAL_PEERS = [ "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", + "/ip4/193.106.95.184/tcp/46419/p2p/12D3KooWHqdGgDZZRCRDKqiikB1ofC3xLnV3oUynepUfDjNh5g9X", ] From 712f5a330f4c9e590d9d9d06a031045c04dd611e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 6 Jan 2023 10:51:12 +0400 Subject: [PATCH 014/168] Remove backup bootstrap peer --- src/petals/constants.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/petals/constants.py b/src/petals/constants.py index 27c8c70..a4620c3 100644 --- a/src/petals/constants.py +++ b/src/petals/constants.py @@ -3,5 +3,4 @@ PUBLIC_INITIAL_PEERS = [ "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", - "/ip4/193.106.95.184/tcp/46419/p2p/12D3KooWHqdGgDZZRCRDKqiikB1ofC3xLnV3oUynepUfDjNh5g9X", ] From 0f6464103d8e714a7e77051b66e64145e87479e8 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 7 Jan 2023 01:55:40 +0400 Subject: [PATCH 015/168] Remove protobuf from requirements (#182) A correct protobuf version should be already installed by hivemind. This also resolves version conflict on Colab, where protobuf versions required by Petals were different from the ones required by pre-installed tensorflow and tensorboard packages. Co-authored-by: Max Ryabinin --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 8c7b19a..11513bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,6 @@ install_requires = accelerate==0.15.0 huggingface-hub==0.11.1 transformers==4.25.1 - protobuf>=3.20.3,<4.0dev speedtest-cli==2.1.3 hivemind==1.1.3 tensor_parallel==1.0.23 From 27406a9377b41e9a3ff95beda7d05699c264ace9 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 7 Jan 2023 10:48:51 +0400 Subject: [PATCH 016/168] Add more links to BLOOM to readme (#183) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e1eeb7f..f08d977 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@

-Generate text using distributed BLOOM and fine-tune it for your own tasks: +Generate text using distributed [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine-tune it for your own tasks: ```python from petals import DistributedBloomForCausalLM @@ -58,7 +58,7 @@ Check out more examples and tutorials: ## How does it work? -- Petals runs large language models like BLOOM-176B **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. +- Petals runs large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. - Inference runs at ≈ 1 sec per step (token) — 10x faster than possible with offloading, enough for chatbots and other interactive apps. Parallel inference reaches hundreds of tokens/sec. - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. You get the comforts of an API with the flexibility of PyTorch. From f344c7801b5849f5fc5cb30ff588b8575742c257 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 8 Jan 2023 08:41:47 +0400 Subject: [PATCH 017/168] Add link to health.petals.ml to readme (#184) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f08d977..323b8b0 100644 --- a/README.md +++ b/README.md @@ -48,11 +48,12 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache -- 💬 If you have any issues or feedback, please join [our Discord server](https://discord.gg/D9MwApKgWa)! -Check out more examples and tutorials: +Check out more examples, tools, and tutorials: -- Chatbot web app (connects to Petals via an HTTP endpoint): [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat) +- Chatbot web app (connects to Petals via an HTTP endpoint): [link](http://chat.petals.ml), [source code](https://github.com/borzunov/chat.petals.ml) - Training a personified chatbot: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) - Fine-tuning BLOOM for text semantic classification: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) +- Public swarm monitor: [link](http://health.petals.ml), [source code](https://github.com/borzunov/health.petals.ml) - Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Running a custom foundation model: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) From 391c8552084c80f2a1508f0a3ce41d665799e392 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 8 Jan 2023 09:59:07 +0400 Subject: [PATCH 018/168] Add readme subsections (#185) --- README.md | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 323b8b0..1c5603b 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,11 @@ for input_ids, labels in data_loader: 🚀  Try now in Colab

-Connect your own GPU and increase Petals capacity — run this in an [Anaconda](https://www.anaconda.com) env: +🔏 Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. + +### Connect your GPU and increase Petals capacity + +Run this in an [Anaconda](https://www.anaconda.com) env: ```bash conda install pytorch cudatoolkit=11.3 -c pytorch @@ -46,17 +50,29 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache -- learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals ``` -💬 If you have any issues or feedback, please join [our Discord server](https://discord.gg/D9MwApKgWa)! +🔒 This does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). + +💬 If you have any issues or feedback, let us know on [our Discord server](https://discord.gg/D9MwApKgWa)! + +### Check out examples, tutorials, and more + +Example apps built with Petals: + +- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP endpoint): [source code](https://github.com/borzunov/chat.petals.ml) + +Fine-tuning the model for your own tasks: -Check out more examples, tools, and tutorials: +- Training a personified chatbot: [tutorial](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) +- Fine-tuning BLOOM for text semantic classification: [tutorial](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) -- Chatbot web app (connects to Petals via an HTTP endpoint): [link](http://chat.petals.ml), [source code](https://github.com/borzunov/chat.petals.ml) -- Training a personified chatbot: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) -- Fine-tuning BLOOM for text semantic classification: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) -- Public swarm monitor: [link](http://health.petals.ml), [source code](https://github.com/borzunov/health.petals.ml) +Useful tools and advanced tutorials: + +- [Monitor](http://health.petals.ml) for the public swarm: [source code](https://github.com/borzunov/health.petals.ml) - Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Running a custom foundation model: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) +📋 If you build an app running BLOOM with Petals, make sure it follows the BLOOM's [terms of use](https://huggingface.co/bigscience/bloom). + ## How does it work? - Petals runs large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. @@ -71,14 +87,6 @@ Check out more examples, tools, and tutorials: 📜  Read paper

-### 🔒 Privacy and security - -The Petals public swarm is designed for research and academic use. **Please do not use the public swarm to process sensitive data.** We ask for that because it is an open network, and it is technically possible for peers serving model layers to recover input data and model outputs or modify them in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process your data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). - -### 📋 Model's terms of use - -Before building your own application that runs a language model with Petals, please check out the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license). - ## FAQ 1. **What's the motivation for people to host model layers in the public swarm?** From 16b69d6050cb16bc325c0bf9e4f855212b9726cb Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 9 Jan 2023 09:54:44 +0400 Subject: [PATCH 019/168] Fix GiBs in the "insufficient disk space" message (#187) --- src/petals/utils/disk_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/petals/utils/disk_cache.py b/src/petals/utils/disk_cache.py index 4d1dbef..eb23477 100644 --- a/src/petals/utils/disk_cache.py +++ b/src/petals/utils/disk_cache.py @@ -73,14 +73,14 @@ def free_disk_space_for( if freed_space >= extra_space_needed: break + gib = 1024**3 if pending_removal: - gib = 1024**3 logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space") delete_strategy = cache_info.delete_revisions(*pending_removal) delete_strategy.execute() if freed_space < extra_space_needed: raise RuntimeError( - f"Insufficient disk space to load a block. Please free {extra_space_needed - freed_space:.1f} GiB " + f"Insufficient disk space to load a block. Please free {(extra_space_needed - freed_space) / gib:.1f} GiB " f"on the volume for {cache_dir} or increase --max_disk_space if you set it manually" ) From 93bed7da5a128d99daf6ac94f671eb640510446c Mon Sep 17 00:00:00 2001 From: Egiazarian Vage Date: Mon, 9 Jan 2023 20:41:23 +0400 Subject: [PATCH 020/168] Support libp2p relays for NAT traversal (#186) - Added relay options to servers - Enabled relay options by default - Changed hivemind version to 1.1.5 - Moved reachability check to be performed after blocks are loaded Co-authored-by: Alexander Borzunov --- setup.cfg | 2 +- src/petals/cli/run_server.py | 3 ++ src/petals/client/remote_model.py | 2 ++ src/petals/server/reachability.py | 39 ++++++++++++++++++++++++ src/petals/server/server.py | 50 +++++++++++++------------------ 5 files changed, 66 insertions(+), 30 deletions(-) create mode 100644 src/petals/server/reachability.py diff --git a/setup.cfg b/setup.cfg index 11513bd..3ba993e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = huggingface-hub==0.11.1 transformers==4.25.1 speedtest-cli==2.1.3 - hivemind==1.1.3 + hivemind==1.1.5 tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index e089937..fc0771d 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -38,6 +38,9 @@ def main(): 'This is a simplified way to set the --announce_maddrs option (see below).' 'Default: server announces IPv4/IPv6 addresses of your network interfaces') + parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay", + help="Do not look for libp2p relays to reach peers behind NATs/firewalls") + parser.add_argument('--host_maddrs', nargs='+', required=False, help='Multiaddrs to listen for external connections from other peers') parser.add_argument('--announce_maddrs', nargs='+', required=False, diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 3e52e40..5d22bfd 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -107,6 +107,8 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel): num_workers=n_layer, startup_timeout=config.daemon_startup_timeout, start=True, + use_relay=True, + use_auto_relay=True, ) ) assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance" diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py new file mode 100644 index 0000000..d8b5fba --- /dev/null +++ b/src/petals/server/reachability.py @@ -0,0 +1,39 @@ +import math +import time + +import requests +from hivemind.utils.logging import get_logger + +logger = get_logger(__file__) + + +def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None: + for attempt_no in range(math.floor(wait_time / retry_delay) + 1): + try: + r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10) + r.raise_for_status() + response = r.json() + + if response["success"]: + logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon") + return + + if attempt_no == 0: + # Usually, libp2p manages to set up relays before we finish loading blocks. + # In other cases, we may need to wait for up to `wait_time` seconds before it's done. + logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes") + time.sleep(retry_delay) + except Exception as e: + logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}") + return + + raise RuntimeError( + f"Server has not become reachable from the Internet:\n\n" + f"{response['message']}\n\n" + f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n" + f" 1. Choose a specific port for the Petals server, for example, 31337.\n" + f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n" + f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n" + f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n" + f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n" + ) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index a8927aa..e1a2293 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -10,7 +10,6 @@ from typing import Dict, List, Optional, Sequence, Union import numpy as np import psutil -import requests import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file @@ -28,6 +27,7 @@ from petals.server.backend import TransformerBackend from petals.server.block_utils import get_block_size from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache +from petals.server.reachability import check_reachability from petals.server.throughput import get_host_throughput from petals.utils.convert_block import check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR @@ -78,6 +78,8 @@ class Server: load_in_8bit: Optional[bool] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, + use_relay: bool = True, + use_auto_relay: bool = True, **kwargs, ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" @@ -117,14 +119,20 @@ class Server: ) self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] - self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs) + self.dht = DHT( + initial_peers=initial_peers, + start=True, + num_workers=self.block_config.n_layer, + use_relay=use_relay, + use_auto_relay=use_auto_relay, + **kwargs, + ) visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()] if initial_peers == PUBLIC_INITIAL_PEERS: logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}") - if not skip_reachability_check: - self._check_reachability() else: logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") + self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" @@ -196,35 +204,14 @@ class Server: self.stop = threading.Event() - def _check_reachability(self): - try: - r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10) - r.raise_for_status() - response = r.json() - except Exception as e: - logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}") - return - - if not response["success"]: - # This happens only if health.petals.ml is up and explicitly told us that we are unreachable - raise RuntimeError( - f"Server is not reachable from the Internet:\n\n" - f"{response['message']}\n\n" - f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n" - f" 1. Choose a specific port for the Petals server, for example, 31337.\n" - f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n" - f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n" - f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n" - f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n" - ) - - logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon") - def _choose_num_blocks(self) -> int: assert ( self.converted_model_name_or_path == "bigscience/bloom-petals" ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually" - assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually" + assert self.device.type == "cuda", ( + "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. " + "CPU-only servers in the public swarm are discouraged since they are much slower" + ) num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1 if num_devices > 1: @@ -287,6 +274,7 @@ class Server: use_auth_token=self.use_auth_token, load_in_8bit=self.load_in_8bit, tensor_parallel_devices=self.tensor_parallel_devices, + need_reachability_check=self.need_reachability_check, start=True, ) try: @@ -380,6 +368,7 @@ class ModuleContainer(threading.Thread): use_auth_token: Optional[str], load_in_8bit: bool, tensor_parallel_devices: Sequence[torch.device], + need_reachability_check: bool, **kwargs, ) -> ModuleContainer: module_uids = [f"{prefix}.{block_index}" for block_index in block_indices] @@ -433,6 +422,9 @@ class ModuleContainer(threading.Thread): min_batch_size=min_batch_size, max_batch_size=max_batch_size, ) + + if need_reachability_check: + check_reachability(dht.peer_id) except: logger.debug("Shutting down backends") for backend in blocks.values(): From a617ce3cfa9e08eee2fe4982343e0689c5fc5cfe Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 10 Jan 2023 13:04:52 +0400 Subject: [PATCH 021/168] Fix psutil-related AccessDenied crash, disable --load_in_8bit by default in case of TP (#188) * Don't count open fds since it leads to AccessDenied crashes on some machines * Use --load_in_8bit=False by default in case of tensor parallelism * Install petals from PyPI in fine-tuning tutorials --- examples/prompt-tuning-personachat.ipynb | 3 +-- examples/prompt-tuning-sst2.ipynb | 3 +-- src/petals/server/server.py | 27 ++++++++++++------------ 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/examples/prompt-tuning-personachat.ipynb b/examples/prompt-tuning-personachat.ipynb index f6031ad..ff0eac7 100644 --- a/examples/prompt-tuning-personachat.ipynb +++ b/examples/prompt-tuning-personachat.ipynb @@ -36,8 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q git+https://github.com/bigscience-workshop/petals\n", - "!pip install -q datasets wandb" + "%pip install -q petals datasets wandb" ] }, { diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index a94a51b..bf985a9 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -36,8 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q git+https://github.com/bigscience-workshop/petals\n", - "!pip install -q datasets wandb" + "%pip install -q petals datasets wandb" ] }, { diff --git a/src/petals/server/server.py b/src/petals/server/server.py index e1a2293..57d743e 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -9,7 +9,6 @@ import time from typing import Dict, List, Optional, Sequence, Union import numpy as np -import psutil import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file @@ -28,7 +27,7 @@ from petals.server.block_utils import get_block_size from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.reachability import check_reachability -from petals.server.throughput import get_host_throughput +from petals.server.throughput import get_dtype_name, get_host_throughput from petals.utils.convert_block import check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR @@ -146,12 +145,6 @@ class Server: assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" self.torch_dtype = torch_dtype - if load_in_8bit is None: - load_in_8bit = device.type == "cuda" - if load_in_8bit: - logger.info("Model weights will be loaded in 8-bit format") - self.load_in_8bit = load_in_8bit - if tensor_parallel_devices is None: tensor_parallel_devices = (device,) self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices)) @@ -159,6 +152,17 @@ class Server: logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}") check_device_balance(self.tensor_parallel_devices) + if load_in_8bit is None: + load_in_8bit = device.type == "cuda" + if load_in_8bit and len(self.tensor_parallel_devices) > 1: + load_in_8bit = False + logger.warning( + "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. " + "You can explicitly set `--load_in_8bit True` to override this" + ) + self.load_in_8bit = load_in_8bit + logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format") + assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" if num_blocks is None and block_indices is None: num_blocks = self._choose_num_blocks() @@ -167,8 +171,7 @@ class Server: first_block_index, last_block_index = block_indices.split(":") first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index))) except Exception as e: - logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)") - raise + raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)") block_indices = range(first_block_index, last_block_index) num_blocks = len(block_indices) self.strict_block_indices, self.num_blocks = block_indices, num_blocks @@ -301,10 +304,6 @@ class Server: del self.module_container gc.collect() # In particular, this closes unused file descriptors - cur_proc = psutil.Process() - num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)] - logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors") - if self.device.type == "cuda": torch.cuda.empty_cache() From 82c9f93ce6e43360cf47314f6dd4d52beafa3bfd Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 10 Jan 2023 15:47:58 +0400 Subject: [PATCH 022/168] Bump version to 1.1.0 (#190) --- setup.cfg | 2 +- src/petals/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 3ba993e..ad197d1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = petals -version = 1.0.0 +version = attr: petals.__version__ author = Petals Developers author_email = petals-dev@googlegroups.com description = Easy way to efficiently run 100B+ language models without high-end GPUs diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 667094c..13b2cd4 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,6 +1,6 @@ from petals.client import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.0.0" +__version__ = "1.1.0" _initialize_logs() From 487411e87ef5562936b22180766734dc5152a4f5 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 11 Jan 2023 02:28:49 +0400 Subject: [PATCH 023/168] Fix fine-tuning notebooks intros (#194) The notebook intros were outdated and mentioned the 6B model, while the actual code already runs the 176B model. This led to confusion among our users in Discord. --- examples/prompt-tuning-personachat.ipynb | 4 ++-- examples/prompt-tuning-sst2.ipynb | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/prompt-tuning-personachat.ipynb b/examples/prompt-tuning-personachat.ipynb index ff0eac7..943bb61 100644 --- a/examples/prompt-tuning-personachat.ipynb +++ b/examples/prompt-tuning-personachat.ipynb @@ -11,9 +11,9 @@ "\n", "# Distributed Bloom for Text Generation using Prompt Tuning\n", "\n", - "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", + "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", "\n", - "We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n", + "We will adapt BLOOM for the task of creating a chatbot with a specific personality using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n", "\n", "To use this notebook in Colab:\n", "\n", diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index bf985a9..d99a48d 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -11,9 +11,9 @@ "\n", "# Distributed Bloom for Text Classification using Prompt Tuning\n", "\n", - "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", + "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", "\n", - "We will adapt the BLOOM model for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", + "We will adapt BLOOM for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", "\n", "To use this notebook in Colab:\n", "\n", From 127cf66beed3e0f66c408af82ea70350253c9062 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 11 Jan 2023 16:37:30 +0400 Subject: [PATCH 024/168] Ignore network RPS if we failed to measure it (#198) --- src/petals/server/throughput.py | 34 ++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 73ad973..f491a10 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -101,25 +101,25 @@ def measure_throughput_info( logger.info( "Measuring network and compute throughput. This takes about a minute and will be cached for future runs" ) - return min( - measure_network_rps(config), - measure_compute_rps( - config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices - ), - ) - -def measure_network_rps(config: BloomConfig) -> float: + result = measure_compute_rps( + config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices + ) try: - s = speedtest.Speedtest() - s.get_servers() - s.get_best_server() - s.download() - s.upload() - network_info = s.results.dict() - except: - logger.error("Failed to measure network throughput:") - raise + result = min(result, measure_network_rps(config)) + except Exception: + logger.warning("Failed to measure network throughput:", exc_info=True) + logger.warning("Proceeding with the compute throughput only") + return result + + +def measure_network_rps(config: BloomConfig) -> Optional[float]: + s = speedtest.Speedtest() + s.get_servers() + s.get_best_server() + s.download() + s.upload() + network_info = s.results.dict() bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request From b4f3224cda1405a002643d25e42ab30313b1c266 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 11 Jan 2023 16:50:24 +0400 Subject: [PATCH 025/168] Make client ignore blacklist if all servers holding a block are blacklisted (#197) If all servers holding a certain block are blacklisted, we should display errors from them instead of raising `No peers holding blocks`. Indeed, if the error is client-caused, the client should learn its reason from the latest error messages. In turn, if the error is server/network-caused and we only have a few servers, we'd better know the error instead of banning all the servers and making the user think that no servers are available. --- src/petals/client/inference_session.py | 4 ++-- src/petals/client/routing/sequence_manager.py | 18 ++++++++++++++---- src/petals/client/sequential_autograd.py | 5 ++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index b7a068b..3d41b6f 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -17,7 +17,7 @@ from hivemind import ( serialize_torch_tensor, ) from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2PHandlerError, StubBase +from hivemind.p2p import StubBase from hivemind.proto import runtime_pb2 from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback @@ -305,7 +305,7 @@ class InferenceSession: self._sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None and not isinstance(e, P2PHandlerError): + if span is not None: self._sequence_manager.on_request_failure(span.peer_id) delay = self._sequence_manager.get_retry_delay(attempt_no) logger.warning( diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index bb93158..d77a575 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -156,10 +156,20 @@ class RemoteSequenceManager: for block_info in new_block_infos: if not block_info: continue - for peer_id in tuple(block_info.servers.keys()): - if peer_id in self.banned_peers: - logger.debug(f"Ignoring banned {peer_id} for block {block_info.uid}") - block_info.servers.pop(peer_id) + valid_servers = { + peer_id: server_info + for peer_id, server_info in block_info.servers.items() + if peer_id not in self.banned_peers + } + if len(valid_servers) < len(block_info.servers): + if valid_servers: + logger.debug( + f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}" + ) + block_info.servers = valid_servers + else: + # If we blacklisted all servers, the error may actually be client-caused + logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist") with self.lock_changes: self.sequence_info.update_(new_block_infos) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 8ee786d..30c20ad 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -10,7 +10,6 @@ from typing import List, Optional, Sequence, Tuple import torch from hivemind import MSGPackSerializer from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2PHandlerError from hivemind.utils.logging import get_logger from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward @@ -94,7 +93,7 @@ async def sequential_forward( sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None and not isinstance(e, P2PHandlerError): + if span is not None: sequence_manager.on_request_failure(span.peer_id) delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( @@ -171,7 +170,7 @@ async def sequential_backward( sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None and not isinstance(e, P2PHandlerError): + if span is not None: sequence_manager.on_request_failure(span.peer_id) delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( From c2cb6d19ae397bb1830478f635c34f03c5deedff Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 11 Jan 2023 17:54:24 +0300 Subject: [PATCH 026/168] Increase tolerances in test_tp_block (#196) deflapify tests --- tests/test_tensor_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 40eb1ee..9d3ba59 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -40,7 +40,7 @@ def test_tp_block(devices, custom_config): y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past) y_ours.backward(grad_proj) - assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6) - assert torch.allclose(y_ours, y_ref, atol=1e-6) + assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5) + assert torch.allclose(y_ours, y_ref, atol=1e-5) assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4) assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4) From 42d1bbb568b365943f011c7dfd531fc2b686d7bf Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 11 Jan 2023 22:27:14 +0400 Subject: [PATCH 027/168] Fix --no_auto_relay help (#199) --- src/petals/cli/run_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index fc0771d..ff68966 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -39,7 +39,7 @@ def main(): 'Default: server announces IPv4/IPv6 addresses of your network interfaces') parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay", - help="Do not look for libp2p relays to reach peers behind NATs/firewalls") + help="Do not look for libp2p relays to become reachable if we are behind NAT/firewall") parser.add_argument('--host_maddrs', nargs='+', required=False, help='Multiaddrs to listen for external connections from other peers') From 012f840f7e5c21f4ca5c9090e46af62f1fac7715 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 11 Jan 2023 23:26:09 +0300 Subject: [PATCH 028/168] Use length-weighted sampling in routing for inference (#204) This pull-request implements a simple (1) greedy (2) latency-agnostic routing optimization that should speed up both our use cases. Why this exists: our effort to merge full routing (ping-aware, throughut-aware, dijkstra) is in a sorry state between several branches; merging it into main would take many days. Co-authored-by: Aleksandr Borzunov --- src/petals/client/inference_session.py | 2 +- src/petals/client/routing/sequence_manager.py | 15 +++++++++++++-- src/petals/client/sequential_autograd.py | 2 +- tests/test_sequence_manager.py | 5 +++-- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 3d41b6f..95e5ff5 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -255,7 +255,7 @@ class InferenceSession: ) recovery_until = max(recovery_until, update_end) - updated_spans = self._sequence_manager.make_sequence(block_idx, update_end) + updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="fastest") # make_sequence() could return a longer sequence updated_spans[-1].end = min(updated_spans[-1].end, update_end) updated_sessions = self._enter_server_sessions(updated_spans) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index d77a575..2b282bc 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -9,6 +9,7 @@ import time from typing import Any, Dict, List, Optional, Sequence, Union from weakref import WeakMethod +import numpy as np from hivemind import DHT, P2P, MSGPackSerializer, PeerID from hivemind.dht.node import Blacklist from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker @@ -92,12 +93,15 @@ class RemoteSequenceManager: if await_ready: self._thread.ready.wait(timeout) - def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]: + def make_sequence( + self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random" + ) -> List[RemoteSpanInfo]: """ Form a sequence of remote servers that collectively serve all consecutive layers :param start_index: optional index of the first module in a sequence, default = the first of block_uids :param end_index: optional index of the last module (non-inclusive), default = after last of block uids + :param mode: either random or fastest """ if not self.is_alive(): logger.error("Using a sequence manager that is not running: it has either crashed or never started") @@ -110,7 +114,14 @@ class RemoteSequenceManager: current_index = start_index while current_index < end_index: candidate_spans = self.sequence_info.spans_containing_block[current_index] - chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing + if mode == "random": + chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing + elif mode == "fastest": + # note: this too is a heuristic that will be replaced once we integrate fastest wall time routing + span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64) + chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) + else: + raise RuntimeError(f"Unexpected mode {mode}") assert chosen_span.start <= current_index < chosen_span.end span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id)) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 30c20ad..debcb7b 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -60,7 +60,7 @@ async def sequential_forward( span = None try: if not sequences or attempt_no >= 1: - sequences = deque(sequence_manager.make_sequence(block_idx, end_index)) + sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="random")) # make_sequence() could return a longer sequence sequences[-1].end = min(sequences[-1].end, end_index) logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers") diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 69d05c4..29562c3 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -14,7 +14,8 @@ logger = get_logger(__file__) @pytest.mark.forked -def test_sequence_manager_shutdown(): +@pytest.mark.parametrize("mode", ["fastest", "random"]) +def test_sequence_manager_basics(mode: str): config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) sequential = RemoteSequential(config, dht) @@ -28,7 +29,7 @@ def test_sequence_manager_shutdown(): sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True), ) - sequence = sequential.sequence_manager.make_sequence() + sequence = sequential.sequence_manager.make_sequence(mode=mode) assert all(sequence[i].peer_id != sequence[i + 1].peer_id for i in range(len(sequence) - 1)) assert sequential.sequence_manager.is_alive() From 37373a66c3010c03d259f18e6606ae68bbc22e18 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 12 Jan 2023 01:43:25 +0400 Subject: [PATCH 029/168] Update Anaconda installation commands (#205) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1c5603b..d742123 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ for input_ids, labels in data_loader: Run this in an [Anaconda](https://www.anaconda.com) env: ```bash -conda install pytorch cudatoolkit=11.3 -c pytorch +conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install -U petals python -m petals.cli.run_server bigscience/bloom-petals ``` @@ -106,7 +106,7 @@ Useful tools and advanced tutorials: Here's how to install Petals with conda: ```bash -conda install pytorch cudatoolkit=11.3 -c pytorch +conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install -U petals ``` @@ -122,7 +122,7 @@ __System requirements:__ Petals only supports Linux for now. If you don't have a Petals uses pytest with a few plugins. To install them, run: ```bash -conda install pytorch cudatoolkit=11.3 -c pytorch +conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia git clone https://github.com/bigscience-workshop/petals.git && cd petals pip install -e .[dev] ``` From 5f58f006495ca5fe1f96f74dd7e4de6315ff52d7 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 12 Jan 2023 06:49:41 +0300 Subject: [PATCH 030/168] Return available cache size in rpc_info() (#191) This PR makes servers return their free cache (in tokens * layers to make it compression-agnostic) To be used when calling make_sequence(optimize="inference") --- src/petals/server/backend.py | 5 +++++ src/petals/server/handler.py | 18 +++++++++++++++++ tests/test_server_stats.py | 39 ++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 tests/test_server_stats.py diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 67b03c0..4f9a3bb 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -1,6 +1,7 @@ """Code for serving bloom blocks via hivemind-server""" from __future__ import annotations +from collections import Counter from itertools import chain from typing import Any, Dict, Sequence, Tuple @@ -64,6 +65,10 @@ class TransformerBackend(ModuleBackend): self.kwargs_schema, ) + self.cache_bytes_per_token: Dict[torch.device, int] = Counter() + for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1): + self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8 + def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]: """Create tensor descriptors for attention cache tensors used during inference_step""" head_dim = self.config.hidden_size // self.config.n_head diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 387431a..6ddfb55 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -33,6 +33,8 @@ from petals.utils.misc import DUMMY, is_dummy logger = get_logger(__file__) +CACHE_TOKENS_AVAILABLE = "cache_tokens_available" + class TransformerConnectionHandler(ConnectionHandler): """Handles three request types: forward, backward and forward-incremental (inference)""" @@ -378,6 +380,22 @@ class TransformerConnectionHandler(ConnectionHandler): else: logger.warning(f"{message}: {warning}") + async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo: + """Return metadata about stored block uids and current load""" + rpc_info = {} + if request.uid: + backend = self.module_backends[request.uid] + rpc_info.update(self.module_backends[request.uid].get_info()) + else: + backend = next(iter(self.module_backends.values())) + # not saving keys to rpc_info since user did not request any uid + + cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes) + if CACHE_TOKENS_AVAILABLE in rpc_info: + raise RuntimeError(f"Block rpc_info dict has a reserved field {CACHE_TOKENS_AVAILABLE} : {rpc_info}") + rpc_info[CACHE_TOKENS_AVAILABLE] = cache_bytes_left // max(backend.cache_bytes_per_token.values()) + return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(rpc_info)) + async def _rpc_forward( *flat_tensors: torch.Tensor, diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py new file mode 100644 index 0000000..0f2b3f0 --- /dev/null +++ b/tests/test_server_stats.py @@ -0,0 +1,39 @@ +import time + +import hivemind +import pytest +import torch +from test_utils import * + +from petals.client import DistributedBloomConfig +from petals.data_structures import UID_DELIMITER +from petals.dht_utils import get_remote_sequence +from petals.server.handler import CACHE_TOKENS_AVAILABLE + + +@pytest.mark.forked +def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50): + dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) + config = DistributedBloomConfig.from_pretrained(MODEL_NAME) + + blocks1 = get_remote_sequence(dht, block_from, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}") + blocks2 = get_remote_sequence(dht, block_to - 1, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}") + info_before = blocks1.sequence_manager.rpc_info + + with blocks1.inference_session(max_length=max_length) as sess: + sess.step(torch.randn(1, 1, config.hidden_size)) + blocks1.sequence_manager._rpc_info = None # invalidate cache + info_inside = blocks1.sequence_manager.rpc_info + + with blocks2.inference_session(max_length=max_length2) as sess2: + sess2.step(torch.randn(1, 1, config.hidden_size)) + blocks2.sequence_manager._rpc_info = None # invalidate cache + info_inside2 = blocks2.sequence_manager.rpc_info + + time.sleep(0.1) + blocks1.sequence_manager._rpc_info = None # invalidate cache + info_after = blocks1.sequence_manager.rpc_info + + assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE] + assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1) + assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2) From 771ca590e7d52e592741750970e300bdbd90d7c7 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 13 Jan 2023 02:05:39 +0300 Subject: [PATCH 031/168] Add service checking direct reachability from peers (#195) Servers joining from behind NATs/firewalls usually take several minutes to join a libp2p relay before they become accessible from the outside Internet. Moreover, requests to such servers are slower and more likely to fail (e.g., if the server switches a relay at the moment). If such servers host certain DHT keys, the swarm may occasionally lose read/write access to these keys, which results in: - Clients being unable to find any servers hosting a certain block. - All servers starting rebalancing to the same place to close the alleged "gap" in the swarm. This PRs modifies servers so that DHT keys are only hosted on **directly reachable** servers (the ones who aren't behind NAT/firewall). This way, DHT becomes more stable and works faster. Of course, trhe servers behind NATs/firewalls still accept requests for running inference/forward/backward for blocks they hold (it's more acceptable for this kind of requests to be slower or fail). Co-authored-by: Alexander Borzunov --- src/petals/cli/run_dht.py | 104 +++++++++++++++++++++++ src/petals/constants.py | 3 + src/petals/server/reachability.py | 135 +++++++++++++++++++++++++++++- src/petals/server/server.py | 22 +++-- 4 files changed, 254 insertions(+), 10 deletions(-) create mode 100644 src/petals/cli/run_dht.py diff --git a/src/petals/cli/run_dht.py b/src/petals/cli/run_dht.py new file mode 100644 index 0000000..2f30516 --- /dev/null +++ b/src/petals/cli/run_dht.py @@ -0,0 +1,104 @@ +""" +A copy of run_dht.py from hivemind with the ReachabilityProtocol added: +https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py + +This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm. + +This may be eventually merged to the hivemind upstream. +""" + +import time +from argparse import ArgumentParser +from secrets import token_hex + +from hivemind.dht import DHT, DHTNode +from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.networking import log_visible_maddrs + +from petals.server.reachability import ReachabilityProtocol + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__name__) + + +async def report_status(dht: DHT, node: DHTNode): + logger.info( + f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) " + f"are in the local routing table " + ) + logger.debug(f"Routing table contents: {node.protocol.routing_table}") + logger.info(f"Local storage contains {len(node.protocol.storage)} keys") + logger.debug(f"Local storage contents: {node.protocol.storage}") + + # Contact peers and keep the routing table healthy (remove stale PeerIDs) + await node.get(f"heartbeat_{token_hex(16)}", latest=True) + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--initial_peers", + nargs="*", + help="Multiaddrs of the peers that will welcome you into the existing DHT. " + "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY", + ) + parser.add_argument( + "--host_maddrs", + nargs="*", + default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"], + help="Multiaddrs to listen for external connections from other DHT instances. " + "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0", + ) + parser.add_argument( + "--announce_maddrs", + nargs="*", + help="Visible multiaddrs the host announces for external connections from other DHT instances", + ) + parser.add_argument( + "--use_ipfs", + action="store_true", + help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" ' + "part of the multiaddrs for the initial_peers " + "(no need to specify a particular IPv4/IPv6 host and port)", + ) + parser.add_argument( + "--identity_path", + help="Path to a private key file. If defined, makes the peer ID deterministic. " + "If the file does not exist, writes a new private key to this file.", + ) + parser.add_argument( + "--no_relay", + action="store_false", + dest="use_relay", + help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)", + ) + parser.add_argument( + "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls" + ) + parser.add_argument( + "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT" + ) + + args = parser.parse_args() + + dht = DHT( + start=True, + initial_peers=args.initial_peers, + host_maddrs=args.host_maddrs, + announce_maddrs=args.announce_maddrs, + use_ipfs=args.use_ipfs, + identity_path=args.identity_path, + use_relay=args.use_relay, + use_auto_relay=args.use_auto_relay, + ) + log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs) + + reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True) + + while True: + dht.run_coroutine(report_status, return_future=False) + time.sleep(args.refresh_period) + + +if __name__ == "__main__": + main() diff --git a/src/petals/constants.py b/src/petals/constants.py index a4620c3..da047f1 100644 --- a/src/petals/constants.py +++ b/src/petals/constants.py @@ -4,3 +4,6 @@ PUBLIC_INITIAL_PEERS = [ "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", ] + +# The reachability API is currently used only when connecting to the public swarm +REACHABILITY_API_URL = "http://health.petals.ml" diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py index d8b5fba..7ead055 100644 --- a/src/petals/server/reachability.py +++ b/src/petals/server/reachability.py @@ -1,16 +1,30 @@ +import asyncio import math +import threading import time +from concurrent.futures import Future +from contextlib import asynccontextmanager +from functools import partial +from secrets import token_hex +from typing import Optional import requests -from hivemind.utils.logging import get_logger +from hivemind.dht import DHT, DHTNode +from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker +from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase +from hivemind.proto import dht_pb2 +from hivemind.utils import get_logger -logger = get_logger(__file__) +from petals.constants import REACHABILITY_API_URL +logger = get_logger(__name__) -def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None: + +def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None: + """verify that your peer is reachable from a (centralized) validator, whether directly or through a relay""" for attempt_no in range(math.floor(wait_time / retry_delay) + 1): try: - r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10) + r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10) r.raise_for_status() response = r.json() @@ -37,3 +51,116 @@ def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n" f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n" ) + + +def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]: + """test if your peer is accessible by others in the swarm with the specified network options in **kwargs""" + + async def _check_direct_reachability(): + target_dht = await DHTNode.create(client_mode=True, **kwargs) + try: + protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p) + async with protocol.serve(target_dht.protocol.p2p): + successes = requests = 0 + for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()): + probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id) + if probe_available is None: + continue # remote peer failed to check probe + successes += probe_available + requests += 1 + if requests >= max_peers: + break + + logger.info(f"Direct reachability: {successes}/{requests}") + return (successes / requests) >= threshold if requests > 0 else None + finally: + await target_dht.shutdown() + + return RemoteExpertWorker.run_coroutine(_check_direct_reachability()) + + +STRIPPED_PROBE_ARGS = dict( + dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True, startup_timeout=60 +) + + +class ReachabilityProtocol(ServicerBase): + """Mini protocol to test if a locally running peer is accessible by other devices in the swarm""" + + def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0): + self.probe = probe + self.wait_timeout = wait_timeout + self._event_loop = self._stop = None + + async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]: + """Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond""" + try: + request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes())) + timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2 + response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout) + logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}") + return response.available + except Exception as e: + logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True) + return None + + async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse: + """Help another peer to check its reachability""" + response = dht_pb2.PingResponse(available=True) + check_peer = PeerID(request.peer.node_id) + if check_peer != context.local_id: # remote peer wants us to check someone other than ourselves + response.available = await self.call_check(check_peer, check_peer=check_peer) is True + logger.info( + f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, " + f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}" + ) + return response + + @asynccontextmanager + async def serve(self, p2p: P2P): + try: + await self.add_p2p_handlers(p2p) + yield self + finally: + await self.remove_p2p_handlers(p2p) + + @classmethod + def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]: + protocol = cls(**kwargs) + ready = Future() + + async def _serve_with_probe(): + try: + common_p2p = await dht.replicate_p2p() + protocol._event_loop = asyncio.get_event_loop() + protocol._stop = asyncio.Event() + + initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)] + for info in await common_p2p.list_peers(): + initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs) + protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS) + + ready.set_result(True) + logger.info("Reachability service started") + + async with protocol.serve(common_p2p): + await protocol._stop.wait() + except Exception as e: + logger.warning(f"Reachability service failed: {repr(e)}") + logger.debug("See detailed traceback below:", exc_info=True) + + if not ready.done(): + ready.set_exception(e) + finally: + if protocol is not None and protocol.probe is not None: + await protocol.probe.shutdown() + logger.debug("Reachability service shut down") + + threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start() + if await_ready: + ready.result() # Propagates startup exceptions, if any + return protocol + + def shutdown(self): + if self._event_loop is not None and self._stop is not None: + self._event_loop.call_soon_threadsafe(self._stop.set) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 57d743e..7e76080 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -26,7 +26,7 @@ from petals.server.backend import TransformerBackend from petals.server.block_utils import get_block_size from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache -from petals.server.reachability import check_reachability +from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability from petals.server.throughput import get_dtype_name, get_host_throughput from petals.utils.convert_block import check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR @@ -77,6 +77,7 @@ class Server: load_in_8bit: Optional[bool] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, + dht_client_mode: Optional[bool] = None, use_relay: bool = True, use_auto_relay: bool = True, **kwargs, @@ -118,20 +119,27 @@ class Server: ) self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] + if dht_client_mode is None: + is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs) + dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer + logger.info(f"This server will run DHT in {'client' if dht_client_mode else 'full peer'} mode") self.dht = DHT( initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, use_relay=use_relay, use_auto_relay=use_auto_relay, + client_mode=dht_client_mode, **kwargs, ) + self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None + visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()] if initial_peers == PUBLIC_INITIAL_PEERS: logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}") else: logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") - self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS + self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" @@ -277,7 +285,7 @@ class Server: use_auth_token=self.use_auth_token, load_in_8bit=self.load_in_8bit, tensor_parallel_devices=self.tensor_parallel_devices, - need_reachability_check=self.need_reachability_check, + should_validate_reachability=self.should_validate_reachability, start=True, ) try: @@ -335,6 +343,8 @@ class Server: def shutdown(self): self.stop.set() + if self.reachability_protocol is not None: + self.reachability_protocol.shutdown() self.dht.shutdown() self.dht.join() @@ -367,7 +377,7 @@ class ModuleContainer(threading.Thread): use_auth_token: Optional[str], load_in_8bit: bool, tensor_parallel_devices: Sequence[torch.device], - need_reachability_check: bool, + should_validate_reachability: bool, **kwargs, ) -> ModuleContainer: module_uids = [f"{prefix}.{block_index}" for block_index in block_indices] @@ -422,8 +432,8 @@ class ModuleContainer(threading.Thread): max_batch_size=max_batch_size, ) - if need_reachability_check: - check_reachability(dht.peer_id) + if should_validate_reachability: + validate_reachability(dht.peer_id) except: logger.debug("Shutting down backends") for backend in blocks.values(): From 6b12b0d050f73826f6f66481d40146370e2bebbb Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 13 Jan 2023 07:46:10 +0400 Subject: [PATCH 032/168] Report server version and dht.client_mode in rpc_info(), check for updates on startup (#209) This PR: 1. Shows the current Petals version and checks for updates on startup. 2. Reports the current version and DHT mode in `rpc_info()`, so it can be shown on http://health.petals.ml or used on clients for efficient routing. --- setup.cfg | 1 + src/petals/cli/run_server.py | 3 +++ src/petals/server/handler.py | 27 ++++++++++++++++----------- src/petals/server/server.py | 2 +- src/petals/utils/version.py | 26 ++++++++++++++++++++++++++ 5 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 src/petals/utils/version.py diff --git a/setup.cfg b/setup.cfg index ad197d1..73bb117 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ install_requires = humanfriendly async-timeout>=4.0.2 cpufeature>=0.2.0 + packaging>=23.0 [options.extras_require] dev = diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index ff68966..135720d 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -8,6 +8,7 @@ from humanfriendly import parse_size from petals.constants import PUBLIC_INITIAL_PEERS from petals.server.server import Server +from petals.utils.version import validate_version logger = get_logger(__file__) @@ -193,6 +194,8 @@ def main(): if load_in_8bit is not None: args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"] + validate_version() + server = Server( **args, host_maddrs=host_maddrs, diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 6ddfb55..3c889f6 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -24,6 +24,7 @@ from hivemind.utils.asyncio import amap_in_executor, anext from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming +import petals from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID from petals.server.backend import TransformerBackend from petals.server.memory_cache import Handle @@ -382,19 +383,23 @@ class TransformerConnectionHandler(ConnectionHandler): async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo: """Return metadata about stored block uids and current load""" - rpc_info = {} - if request.uid: - backend = self.module_backends[request.uid] - rpc_info.update(self.module_backends[request.uid].get_info()) - else: - backend = next(iter(self.module_backends.values())) - # not saving keys to rpc_info since user did not request any uid + backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values())) cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes) - if CACHE_TOKENS_AVAILABLE in rpc_info: - raise RuntimeError(f"Block rpc_info dict has a reserved field {CACHE_TOKENS_AVAILABLE} : {rpc_info}") - rpc_info[CACHE_TOKENS_AVAILABLE] = cache_bytes_left // max(backend.cache_bytes_per_token.values()) - return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(rpc_info)) + result = { + "version": petals.__version__, + "dht_client_mode": self.dht.client_mode, + CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()), + } + + if request.uid: + block_info = self.module_backends[request.uid].get_info() + common_keys = set(result.keys()) & set(block_info.keys()) + if common_keys: + raise RuntimeError(f"The block's rpc_info has keys reserved for the server's rpc_info: {common_keys}") + result.update(block_info) + + return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result)) async def _rpc_forward( diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 7e76080..dca2ccd 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -102,7 +102,7 @@ class Server: f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " f"Please specify --prefix manually when starting a server" ) - logger.info(f"Automatic dht prefix: {prefix}") + logger.debug(f"Automatic dht prefix: {prefix}") self.prefix = prefix if expiration is None: diff --git a/src/petals/utils/version.py b/src/petals/utils/version.py new file mode 100644 index 0000000..b992c27 --- /dev/null +++ b/src/petals/utils/version.py @@ -0,0 +1,26 @@ +import requests +from hivemind.utils.logging import TextStyle, get_logger +from packaging.version import parse + +import petals + +logger = get_logger(__file__) + + +def validate_version(): + logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}") + try: + r = requests.get("https://pypi.python.org/pypi/petals/json") + r.raise_for_status() + response = r.json() + + versions = [parse(ver) for ver in response.get("releases")] + latest = max(ver for ver in versions if not ver.is_prerelease) + + if parse(petals.__version__) < latest: + logger.info( + f"A newer version {latest} is available. Please upgrade with: " + f"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}" + ) + except Exception as e: + logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True) From cc5e5d32c007ebfa21d8a1ceb6776851b16bb1bf Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 13 Jan 2023 08:45:53 +0400 Subject: [PATCH 033/168] Don't switch blocks if it makes swarm disjoint (#210) Even if the swarm seems to have at least 2 servers for each block, turning off on one of the servers could break it. That's because once a server is turned off, others may move to a better position, creating a significant downtime on their way. This PR prohibits switching blocks if it would make the swarm disjoint along the way. --- src/petals/server/block_selection.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/petals/server/block_selection.py b/src/petals/server/block_selection.py index 203b781..33161b2 100644 --- a/src/petals/server/block_selection.py +++ b/src/petals/server/block_selection.py @@ -79,6 +79,9 @@ def should_choose_other_blocks( # Also, subtracting local_span.throughput * (1 + eps) makes _choose_best_start() prefer # the previous server position in case of other things being almost equal. + if initial_throughput > eps and throughputs.min() <= 0: + return False # Switching blocks would make the swarm disjoint + new_start = _choose_best_start(throughputs, local_span.length) if local_span.start == new_start: return False # This server is on its best place already From 6ba63c6cc8103e22d755258bf7cc78937262a61e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 13 Jan 2023 16:27:10 +0400 Subject: [PATCH 034/168] Fix output shape when resuming generation (#211) Before this PR, `model.generate()` returned one excess token when resuming generation with an existing (the last token of the previous session, `session.last_token_id`). This is an unexpected behavior not convenient for the downstream apps, so this PR changes it until it's too late. --- setup.cfg | 2 +- src/petals/client/remote_generation.py | 11 +++++++---- src/petals/server/throughput.py | 2 ++ 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/setup.cfg b/setup.cfg index 73bb117..a05ae6b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ install_requires = humanfriendly async-timeout>=4.0.2 cpufeature>=0.2.0 - packaging>=23.0 + packaging>=20.9 [options.extras_require] dev = diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index 4182ca8..af4166d 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -104,17 +104,18 @@ class RemoteGenerationMixin: elif max_length is None and max_new_tokens is not None: max_length = prefix_length + max_new_tokens - if num_beams > 1 and session is not None: + resuming_session = session is not None and session.last_token_id is not None + if num_beams > 1 and resuming_session: raise NotImplementedError( - "Reusing inference session in .generate() along with beam search is not supported yet" + "Resuming inference session in .generate() along with beam search is not supported yet" ) if inputs is not None: assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]" - if session is not None and session.last_token_id is not None: + if resuming_session: inputs = torch.cat([session.last_token_id, inputs], dim=1) else: - if session is not None and session.last_token_id is not None: + if resuming_session: inputs = session.last_token_id else: assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs" @@ -207,6 +208,8 @@ class RemoteGenerationMixin: outputs = torch.cat(outputs, dim=-1) + if resuming_session: + outputs = outputs[:, 1:] if num_beams > 1: pre_return_idx = [ torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index f491a10..8b6dc9c 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -123,6 +123,8 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]: bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request + if network_rps == 0: + raise ValueError("speedtest has returned network_rps == 0") logger.info( f"Network throughput: " From 5ff250bee9407090fd0216db6ed3b289f6638280 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 13 Jan 2023 17:53:00 +0400 Subject: [PATCH 035/168] Improve errors in case of missing blocks, suggest to join your own server (#212) --- .github/workflows/run-tests.yaml | 4 ++-- src/petals/client/routing/sequence_manager.py | 17 ++++++++++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index af6299b..eb9c988 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -21,11 +21,11 @@ jobs: uses: actions/cache@v2 with: path: ~/.cache/pip - key: Key-v1-py3.9-${{ hashFiles('setup.cfg') }} + key: Key-v1-3.9-${{ hashFiles('setup.cfg') }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . + pip install .[dev] - name: Delete any test models older than 1 week run: | python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 2b282bc..441b9d4 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -114,6 +114,8 @@ class RemoteSequenceManager: current_index = start_index while current_index < end_index: candidate_spans = self.sequence_info.spans_containing_block[current_index] + if not candidate_spans: + raise MissingBlocksError(current_index) if mode == "random": chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing elif mode == "fastest": @@ -186,7 +188,7 @@ class RemoteSequenceManager: self.sequence_info.update_(new_block_infos) missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]] if missing_blocks: - raise MissingBlocksError(f"no servers holding blocks {missing_blocks}") + raise MissingBlocksError(missing_blocks) self.ready.set() # if there is an active server for every block, we may begin running break @@ -245,7 +247,7 @@ class RemoteSequenceManager: if server.state == ServerState.ONLINE ] if not active_servers: - raise MissingBlocksError("no servers holding the first block are online") + raise MissingBlocksError(0) peer_id = random.choice(active_servers) stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id) @@ -334,6 +336,11 @@ def maybe_log_traceback(exc: Exception): logger.log(traceback_level, "See detailed traceback below:", exc_info=True) -class MissingBlocksError(Exception): - def __repr__(self): - return self.args[0] +class MissingBlocksError(RuntimeError): + def __init__(self, block_indices: Union[int, Sequence[int]]): + super().__init__( + f"No servers holding blocks {block_indices} are online.\n" + f"You can check the public swarm's state at http://health.petals.ml\n\n" + f"If there are not enough servers, please consider connecting your own GPU:\n" + f"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity" + ) From 825f5dbf2d7109338755f8a31bf679cd185b4d10 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 13 Jan 2023 19:53:57 +0400 Subject: [PATCH 036/168] CI: Convert model only when convert_model.py or setup.cfg change (#213) This reduces the test running time by 2 times, unless convert_model.py or setup.cfg are changed. --- .github/workflows/run-tests.yaml | 26 ++++++++++++++++++++------ src/petals/cli/convert_model.py | 6 +++++- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index eb9c988..54614f3 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -12,32 +12,45 @@ jobs: BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }} timeout-minutes: 15 steps: - - uses: actions/checkout@v2 + - name: Checkout + uses: actions/checkout@v2 + - name: Check if the model is cached + id: cache-model + uses: actions/cache@v2 + with: + path: ~/.dummy + key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - name: Set up Python + if: steps.cache-model.outputs.cache-hit != 'true' uses: actions/setup-python@v2 with: python-version: 3.9 - name: Cache dependencies + if: steps.cache-model.outputs.cache-hit != 'true' uses: actions/cache@v2 with: path: ~/.cache/pip key: Key-v1-3.9-${{ hashFiles('setup.cfg') }} - name: Install dependencies + if: steps.cache-model.outputs.cache-hit != 'true' run: | python -m pip install --upgrade pip - pip install .[dev] + pip install . - name: Delete any test models older than 1 week + if: steps.cache-model.outputs.cache-hit != 'true' run: | python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN - name: Delete previous version of this model, if exists + if: steps.cache-model.outputs.cache-hit != 'true' run: | export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \ repo_id='bloom-testing/test-bloomd-560m-$HF_TAG')" || true - name: Convert model and push to hub + if: steps.cache-model.outputs.cache-hit != 'true' run: | - export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") - python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \ + export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} + python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \ --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \ --resize_token_embeddings 50000 @@ -50,7 +63,8 @@ jobs: fail-fast: false timeout-minutes: 15 steps: - - uses: actions/checkout@v2 + - name: Checkout + uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: @@ -66,7 +80,7 @@ jobs: pip install .[dev] - name: Test run: | - export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") + export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG export REF_NAME=bigscience/bloom-560m diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index c4746fd..289c764 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -18,7 +18,7 @@ logger = get_logger(__file__) DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.") parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained") @@ -90,3 +90,7 @@ if __name__ == "__main__": config.save_pretrained(".") logger.info(f"Converted {args.model} and pushed to {args.output_repo}") + + +if __name__ == "__main__": + main() From 702bb5a2c215f5e65f880a3cc805f2c120d7aba1 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 13 Jan 2023 20:16:31 +0400 Subject: [PATCH 037/168] CI: Update deprecated actions, don't measure network RPS (#215) * CI: Switch to actions/cache@v3 (v2 is deprecated) * Don't run measure_network_rps() in tests since it doesn't work well in CI --- .github/workflows/run-tests.yaml | 6 +++--- tests/test_aux_functions.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 54614f3..50509dc 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@v2 - name: Check if the model is cached id: cache-model - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.dummy key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} @@ -27,7 +27,7 @@ jobs: python-version: 3.9 - name: Cache dependencies if: steps.cache-model.outputs.cache-hit != 'true' - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pip key: Key-v1-3.9-${{ hashFiles('setup.cfg') }} @@ -70,7 +70,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Cache dependencies - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pip key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }} diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index 554127f..1986f0a 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -8,7 +8,7 @@ from petals.server.throughput import measure_compute_rps, measure_network_rps @pytest.mark.forked @pytest.mark.parametrize("tensor_parallel", [False, True]) -def test_throughput_basic(tensor_parallel: bool): +def test_compute_throughput(tensor_parallel: bool): config = DistributedBloomConfig.from_pretrained(MODEL_NAME) tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else () compute_rps = measure_compute_rps( @@ -20,5 +20,3 @@ def test_throughput_basic(tensor_parallel: bool): n_steps=10, ) assert isinstance(compute_rps, float) and compute_rps > 0 - network_rps = measure_network_rps(config) - assert isinstance(network_rps, float) and network_rps > 0 From cea83d3356a2719a7a3a6e83db7ee0bea0cdc5a2 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 14 Jan 2023 00:34:46 +0400 Subject: [PATCH 038/168] Bump version to 1.1.1 (#214) --- src/petals/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 13b2cd4..25513a4 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,6 +1,6 @@ from petals.client import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.1.0" +__version__ = "1.1.1" _initialize_logs() From af3da5bb04bf3729c76f6a182d5f82363c0572d3 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 16 Jan 2023 01:53:09 +0400 Subject: [PATCH 039/168] Choose --num_blocks automatically for all models (#217) --- src/petals/server/server.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index dca2ccd..a411fd3 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -216,9 +216,6 @@ class Server: self.stop = threading.Event() def _choose_num_blocks(self) -> int: - assert ( - self.converted_model_name_or_path == "bigscience/bloom-petals" - ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually" assert self.device.type == "cuda", ( "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. " "CPU-only servers in the public swarm are discouraged since they are much slower" @@ -240,10 +237,12 @@ class Server: total_memory = torch.cuda.get_device_properties(self.device).total_memory block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit) + + # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models gib = 1024**3 attn_cache_per_block = 0.5 * gib * num_devices # TODO: This does not account for manually set --attn_cache_size + autograd_memory = 2 * gib * num_devices # GPU memory used for intermediate tensors in rpc_backward - autograd_memory = 2 * gib * num_devices # gpu memory used for intermediate tensors in rpc_backward num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block)) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" From e651d73f112c4ff97f1f09d25d452edf492be819 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 16 Jan 2023 04:35:06 +0400 Subject: [PATCH 040/168] Add one more link to the "Getting started" tutorial (#218) Some people miss the "Try now in Colab" link or don't understand that it leads to the comprehensive tutorial, so I added one more explicit link. --- README.md | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index d742123..c500f7c 100644 --- a/README.md +++ b/README.md @@ -54,22 +54,23 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache -- 💬 If you have any issues or feedback, let us know on [our Discord server](https://discord.gg/D9MwApKgWa)! -### Check out examples, tutorials, and more +### Check out tutorials, examples, and more -Example apps built with Petals: +Basic tutorials: -- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP endpoint): [source code](https://github.com/borzunov/chat.petals.ml) +- Getting started: [tutorial](https://colab.research.google.com/drive/1Ervk6HPNS6AYVr3xVdQnY5a-TjjmLCdQ?usp=sharing) +- Fine-tune BLOOM to be a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) +- Fine-tune BLOOM for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) -Fine-tuning the model for your own tasks: +Example apps built with Petals: -- Training a personified chatbot: [tutorial](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) -- Fine-tuning BLOOM for text semantic classification: [tutorial](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) +- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP endpoint): [source code](https://github.com/borzunov/chat.petals.ml) -Useful tools and advanced tutorials: +Useful tools and advanced guides: - [Monitor](http://health.petals.ml) for the public swarm: [source code](https://github.com/borzunov/health.petals.ml) -- Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) -- Running a custom foundation model: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) +- Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) +- Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) 📋 If you build an app running BLOOM with Petals, make sure it follows the BLOOM's [terms of use](https://huggingface.co/bigscience/bloom). From fa5ac6e3b46e993e180a496555237ff949deec8e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 18 Jan 2023 03:23:21 +0400 Subject: [PATCH 041/168] Mention BLOOMZ in readme (#221) --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c500f7c..e6c6628 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@

-Generate text using distributed [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine-tune it for your own tasks: +Generate text using distributed 176B-parameter [BLOOM](https://huggingface.co/bigscience/bloom) or [BLOOMZ](https://huggingface.co/bigscience/bloomz) and fine-tune them for your own tasks: ```python from petals import DistributedBloomForCausalLM @@ -50,6 +50,8 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache -- learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals ``` +You can also host [BLOOMZ](https://huggingface.co/bigscience/bloomz), a version of BLOOM fine-tuned to follow human instructions in the zero-shot regime — just replace `bloom-petals` with `bloomz-petals`. + 🔒 This does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). 💬 If you have any issues or feedback, let us know on [our Discord server](https://discord.gg/D9MwApKgWa)! From 3189b395f0031daaaf529d571773dc9fb8200205 Mon Sep 17 00:00:00 2001 From: Shuchang Zhou Date: Thu, 19 Jan 2023 22:38:43 +0800 Subject: [PATCH 042/168] Fix a typo in error message (#227) By the code context, it can be inferred that do_sample==False when control reaches this point. --- src/petals/client/remote_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index af4166d..4ea0c9b 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -129,7 +129,7 @@ class RemoteGenerationMixin: decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size) else: if top_k is not None or top_p is not None: - logger.warning("You passed top_k or top_p but did pass do_sample=True. Running greedy sampling") + logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling") decoding_algorithm = GreedyAlgorithm() if num_beams > 1: From c4938bc23efe22e3ab6d638261bfd56c6ad807a9 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 19 Jan 2023 18:38:21 +0300 Subject: [PATCH 043/168] Merge inference pools into one to increase inference speed (#225) It turns out using a separate pool for each block has led to significant slowdown, see #224 for details. --- .github/workflows/check-style.yaml | 4 +- .github/workflows/push-docker-image.yaml | 2 +- .github/workflows/run-tests.yaml | 4 +- src/petals/data_structures.py | 2 + src/petals/server/backend.py | 61 +++++++++++++++++++----- src/petals/server/handler.py | 53 +++++++++----------- src/petals/server/server.py | 16 +++++-- src/petals/server/task_prioritizer.py | 4 +- 8 files changed, 93 insertions(+), 53 deletions(-) diff --git a/.github/workflows/check-style.yaml b/.github/workflows/check-style.yaml index 94b9517..42e1460 100644 --- a/.github/workflows/check-style.yaml +++ b/.github/workflows/check-style.yaml @@ -9,7 +9,7 @@ jobs: black: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: psf/black@stable with: options: "--check --diff" @@ -17,7 +17,7 @@ jobs: isort: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v2 with: python-version: 3.8 diff --git a/.github/workflows/push-docker-image.yaml b/.github/workflows/push-docker-image.yaml index cbad1b2..345b8f2 100644 --- a/.github/workflows/push-docker-image.yaml +++ b/.github/workflows/push-docker-image.yaml @@ -14,7 +14,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Docker meta id: meta diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 50509dc..c1a01e8 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -13,7 +13,7 @@ jobs: timeout-minutes: 15 steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Check if the model is cached id: cache-model uses: actions/cache@v3 @@ -64,7 +64,7 @@ jobs: timeout-minutes: 15 steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v2 with: diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index d5a7181..5d85f07 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -6,6 +6,7 @@ from enum import Enum from typing import Any, Dict, Tuple from hivemind import PeerID +from hivemind.moe.expert_uid import ExpertUID from petals.server.memory_cache import Handle @@ -48,5 +49,6 @@ RPCInfo = Dict[str, Any] @dataclasses.dataclass(frozen=True) class InferenceMetadata: + uid: ExpertUID prefix_length: int cache_handles: Tuple[Handle, ...] diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 4f9a3bb..81f3a33 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -3,10 +3,11 @@ from __future__ import annotations from collections import Counter from itertools import chain -from typing import Any, Dict, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple import torch from hivemind import BatchTensorDescriptor, TensorDescriptor +from hivemind.moe.expert_uid import ExpertUID from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger from tensor_parallel import TensorParallel @@ -15,7 +16,7 @@ from transformers import BloomConfig from transformers.models.bloom.modeling_bloom import BloomAttention from petals.data_structures import InferenceMetadata -from petals.server.memory_cache import MemoryCache +from petals.server.memory_cache import Handle, MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import is_dummy @@ -39,7 +40,7 @@ class TransformerBackend(ModuleBackend): device = self.module.devices[self.module.output_device_index] self.inference_pool = PrioritizedTaskPool( self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference" - ) + ) # note: inference_pools may be merged later, see merge_inference_pools_inplace self.forward_pool = PrioritizedTaskPool( self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward" ) @@ -79,22 +80,20 @@ class TransformerBackend(ModuleBackend): cache_tensors.extend((keys, values)) return cache_tensors + @torch.inference_mode() def inference_step( self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: - with torch.inference_mode(): - assert ( - hidden_states.ndim == 3 - ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" - with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors: - self._reorder_cache_inplace(cache_tensors, hypo_ids) - layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) - hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True) - self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length) - return (hidden_states,) + assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" + with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors: + self._reorder_cache_inplace(cache_tensors, hypo_ids) + layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) + hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True) + self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length) + return (hidden_states,) def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor): """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids""" @@ -139,3 +138,39 @@ class TransformerBackend(ModuleBackend): dummy = torch.tensor([]) for p in self.module.parameters(): p.data = dummy + + +def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]): + """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call""" + assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values()) + first_pool = next(iter(backends.values())).inference_pool + merged_pool = PrioritizedTaskPool( + _MergedInferenceStep(backends), + max_batch_size=first_pool.max_batch_size, + device=first_pool.device, + name=f"merged_inference", + ) + for backend in backends.values(): + assert not backend.inference_pool.is_alive() + backend.inference_pool = merged_pool + + +class _MergedInferenceStep: + def __init__(self, backends: Dict[ExpertUID, TransformerBackend]): + self.backends = backends + + def __call__( + self, + hidden_states: torch.Tensor, + hypo_ids: torch.LongTensor, + inference_infos: Sequence[InferenceMetadata], + *optional_prompts: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, ...]: + assert len(inference_infos) == len( + optional_prompts + ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts" + for inference_info, optional_prompt in zip(inference_infos, optional_prompts): + if optional_prompt is not None: + hidden_states[:, : optional_prompt.shape[1]] += optional_prompt + (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info) + return (hidden_states,) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 3c889f6..b1c36ed 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -141,10 +141,11 @@ class TransformerConnectionHandler(ConnectionHandler): assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" # parse deep prompts (optional argument) - if prompts is None or is_dummy(prompts) or is_dummy(prompts): - prompts = [DUMMY] * len(requested_backends) + if prompts is None or is_dummy(prompts): + prompts = [None] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] if not (len(requested_backends) == len(prompts)): raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") @@ -156,33 +157,26 @@ class TransformerConnectionHandler(ConnectionHandler): f" exceeds pre-allocated maximum {max_length}" ) - # run request tensors through all requested modules, update caches - for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts): - if not is_dummy(prompt): - hidden_states[:, : prompt.shape[1]] += prompt - if hidden_states.numel() == 0: - continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. - # when user wants to pre-allocate cache or check that server *can* allocate that cache - - metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles)) - assert isinstance( - hidden_states, torch.Tensor - ), f"hidden states must be tensor, got {type(hidden_states)}" - assert ( - hidden_states.ndim == 3 - ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" - assert isinstance( - backend.inference_pool, PrioritizedTaskPool - ), "petals support only prioritized pools" - priority = self._prioritizer.prioritize( - hidden_states, - hypo_ids, - points=point_per_piece / len(requested_backends), - backend=backend, - type="inference", - ) - (hidden_states,) = await backend.inference_pool.submit_task( - hidden_states, hypo_ids, metadata, priority=priority + priority = self._prioritizer.prioritize( + hidden_states, + hypo_ids, + points=point_per_piece, + requested_uids=requested_uids, + type="inference", + ) + + inference_infos = tuple( + InferenceMetadata(uid, prefix_length, tuple(handles)) + for uid, handles in zip(requested_uids, cache_handles) + ) + + if hidden_states.numel() == 0: + pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. + # when user wants to pre-allocate cache or check that server *can* allocate that cache + else: + assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" + (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task( + hidden_states, hypo_ids, inference_infos, *prompts, priority=priority ) # serialize and send last layer outputs @@ -444,7 +438,6 @@ async def _rpc_forward( hidden_states.ndim == 3 ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" - # Serialize the overall output return hidden_states diff --git a/src/petals/server/server.py b/src/petals/server/server.py index a411fd3..a25fce6 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -22,7 +22,7 @@ from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos from petals.server import block_selection -from petals.server.backend import TransformerBackend +from petals.server.backend import TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache @@ -453,11 +453,12 @@ class ModuleContainer(threading.Thread): joining_announcer.stop.set() joining_announcer.join() + merge_inference_pools_inplace(blocks) + return cls( dht, blocks, throughput=throughput, - device=device, update_period=update_period, expiration=expiration, **kwargs, @@ -476,7 +477,6 @@ class ModuleContainer(threading.Thread): request_timeout: float, session_timeout: float, step_timeout: float, - device: Union[str, torch.device], start: bool, **kwargs, ): @@ -495,7 +495,7 @@ class ModuleContainer(threading.Thread): ) for _ in range(num_handlers) ] - self.runtime = Runtime(self.module_backends, device=None, **kwargs) + self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs) # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed. self.online_announcer = ModuleAnnouncerThread( list(self.module_backends.keys()), @@ -633,3 +633,11 @@ class ModuleAnnouncerThread(threading.Thread): ) if self.stop.wait(self.update_period): break + + +class RuntimeWithDeduplicatedPools(Runtime): + """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pools = tuple(set(self.pools)) diff --git a/src/petals/server/task_prioritizer.py b/src/petals/server/task_prioritizer.py index 3ec5a90..6490fc5 100644 --- a/src/petals/server/task_prioritizer.py +++ b/src/petals/server/task_prioritizer.py @@ -16,4 +16,6 @@ class DummyTaskPrioritizer(TaskPrioritizerBase): """Simple implementation of TaskPrioritizer which gives constant zero priority for every task""" def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float: - return 0.0 + if kwargs.get("type") == "inference": + return 1.0 # inference steps go first since they are more latency-sensitive + return 2.0 # forward, backward From 0ebf6de11770b0d069b059b3cbb9858a089a19e8 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Sat, 21 Jan 2023 04:05:41 +0100 Subject: [PATCH 044/168] Add citation to readme (#219) Co-authored-by: Alexander Borzunov --- README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/README.md b/README.md index e6c6628..7988473 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,22 @@ The automated tests use a more complex server configuration that can be found [h We use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests. Before committing your code, simply run `black . && isort .` and you will be fine. +## 📜 Citation + +Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel. +[Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188) +_arXiv preprint arXiv:2209.01188,_ 2022. + +```bibtex +@article{borzunov2022petals, + title = {Petals: Collaborative Inference and Fine-tuning of Large Models}, + author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Ryabinin, Max and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin}, + journal = {arXiv preprint arXiv:2209.01188}, + year = {2022}, + url = {https://arxiv.org/abs/2209.01188} +} +``` + --------------------------------------------------------------------------------

From d4c687daca8c68a631cc42db59dca152bf3a4d98 Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Mon, 23 Jan 2023 05:09:14 +0400 Subject: [PATCH 045/168] Fix dtype error in fine-tuning notebooks (#231) --- examples/prompt-tuning-sst2.ipynb | 11 +++++++---- src/petals/client/remote_model.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index d99a48d..5bcb0c9 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -308,6 +308,7 @@ " self.distributed_layers = model.transformer.h\n", "\n", " self.hidden_size = model.config.hidden_size\n", + " self.dtype = model.config.torch_dtype\n", " self.intermediate_size = intermediate_size\n", " self.num_classes = num_classes\n", " self.adapter_layer_position = adapter_layer_position\n", @@ -316,11 +317,11 @@ " self.adapter = nn.Sequential(\n", " nn.Linear(self.hidden_size, self.intermediate_size),\n", " nn.Linear(self.intermediate_size, self.hidden_size),\n", - " )\n", + " ).to(self.dtype)\n", " self.head = nn.Sequential(\n", " nn.LayerNorm(self.hidden_size),\n", " nn.Linear(self.hidden_size, self.num_classes),\n", - " )\n", + " ).to(self.dtype)\n", " \n", " def forward(self, embeddings):\n", " before_layers = self.distributed_layers[0:self.adapter_layer_position]\n", @@ -388,9 +389,10 @@ " head_layer_position=HEAD_LAYER_POSITION,\n", ")\n", "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", + "cls_criterion = nn.CrossEntoryCriterion()\n", "\n", "lr_scheduler = get_scheduler(\n", - " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", + " name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", ")" ] }, @@ -432,6 +434,7 @@ " with torch.no_grad():\n", " embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n", " outputs = cls_model(embeddings_output)\n", + " loss = cls_criterion(outputs, batch[\"labels\"])\n", " loss.backward()\n", "\n", " cls_optimizer.step()\n", @@ -461,7 +464,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.9 (default, Apr 13 2022, 08:48:07) \n[Clang 13.1.6 (clang-1316.0.21.2.5)]" + "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 5d22bfd..af8a20c 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -265,7 +265,7 @@ class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequ self.num_labels = config.num_labels self.transformer = DistributedBloomModel(config) - self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype) # Initialize weights and apply final processing self.post_init() From 5d7395e1b55a9b4b308ab0a4b0df818152148514 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 24 Jan 2023 10:01:31 +0400 Subject: [PATCH 046/168] Prompt-tuning notebooks: suggest to use a smaller model for faster prototyping (#234) --- README.md | 4 ++-- examples/prompt-tuning-personachat.ipynb | 16 +++++++++++++--- examples/prompt-tuning-sst2.ipynb | 16 +++++++++++++--- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 7988473..0341fe5 100644 --- a/README.md +++ b/README.md @@ -61,8 +61,8 @@ You can also host [BLOOMZ](https://huggingface.co/bigscience/bloomz), a version Basic tutorials: - Getting started: [tutorial](https://colab.research.google.com/drive/1Ervk6HPNS6AYVr3xVdQnY5a-TjjmLCdQ?usp=sharing) -- Fine-tune BLOOM to be a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) -- Fine-tune BLOOM for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) +- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) +- Prompt-tune BLOOM for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) Example apps built with Petals: diff --git a/examples/prompt-tuning-personachat.ipynb b/examples/prompt-tuning-personachat.ipynb index 943bb61..bd4d2ae 100644 --- a/examples/prompt-tuning-personachat.ipynb +++ b/examples/prompt-tuning-personachat.ipynb @@ -75,7 +75,18 @@ "metadata": {}, "outputs": [], "source": [ - "MODEL_NAME = \"bigscience/bloom-petals\" # select model you like\n", + "# Choose a model you'd like to prompt-tune. We recommend starting with\n", + "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n", + "# Once your code is ready, you can switch to full-scale\n", + "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n", + "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n", + "\n", + "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", + "# The latter fine-tunes separate prefixes for each transformer block,\n", + "# so prompt-tuning will take more time but yield better results.\n", + "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n", + "TUNING_MODE = 'ptune'\n", + "\n", "NUM_PREFIX_TOKENS = 16\n", "DEVICE = 'cuda'\n", "BATCH_SIZE = 8\n", @@ -83,8 +94,7 @@ "WEIGHT_DECAY = 0.0\n", "NUM_SAMPLES = 1000\n", "SEED = 42\n", - "MODEL_MAX_LENGTH = 256\n", - "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] " + "MODEL_MAX_LENGTH = 256" ] }, { diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 5bcb0c9..05938b0 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -77,7 +77,18 @@ "metadata": {}, "outputs": [], "source": [ - "MODEL_NAME = \"bigscience/bloom-petals\" # select model you like\n", + "# Choose a model you'd like to prompt-tune. We recommend starting with\n", + "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n", + "# Once your code is ready, you can switch to full-scale\n", + "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n", + "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n", + "\n", + "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", + "# The latter fine-tunes separate prefixes for each transformer block,\n", + "# so prompt-tuning will take more time but yield better results.\n", + "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n", + "TUNING_MODE = 'ptune'\n", + "\n", "NUM_PREFIX_TOKENS = 16\n", "DEVICE = 'cuda'\n", "BATCH_SIZE = 16\n", @@ -85,8 +96,7 @@ "WEIGHT_DECAY = 0.0\n", "NUM_EPOCHS = 3\n", "SEED = 42\n", - "MODEL_MAX_LENGTH = 64\n", - "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] " + "MODEL_MAX_LENGTH = 64" ] }, { From b03efb1ef500aa65e120a38aa09f340552249542 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 30 Jan 2023 23:17:38 +0300 Subject: [PATCH 047/168] Bump version to 1.1.2 (#244) --- src/petals/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 25513a4..ec44076 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,6 +1,6 @@ from petals.client import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.1.1" +__version__ = "1.1.2" _initialize_logs() From 5367523df8e3524ba321671528c5b292ddc3ea8f Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 31 Jan 2023 19:06:51 +0600 Subject: [PATCH 048/168] Fix typo in prompt-tuning-sst2.ipynb (#245) --- examples/prompt-tuning-sst2.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 05938b0..9ab9469 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -399,7 +399,7 @@ " head_layer_position=HEAD_LAYER_POSITION,\n", ")\n", "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", - "cls_criterion = nn.CrossEntoryCriterion()\n", + "cls_criterion = nn.CrossEntropyCriterion()\n", "\n", "lr_scheduler = get_scheduler(\n", " name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", From 8766a14d28b7f7e76c24acdb12caa7a6088949e6 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 1 Feb 2023 14:10:45 +0300 Subject: [PATCH 049/168] Minor changes to examples/prompt-tuning notebooks (#247) Minor code changes required to run the notebook in a clean python environment --- examples/prompt-tuning-personachat.ipynb | 4 ++-- examples/prompt-tuning-sst2.ipynb | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/prompt-tuning-personachat.ipynb b/examples/prompt-tuning-personachat.ipynb index bd4d2ae..b9d1bf5 100644 --- a/examples/prompt-tuning-personachat.ipynb +++ b/examples/prompt-tuning-personachat.ipynb @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install -q petals datasets wandb" + "%pip install -q petals datasets wandb scikit-learn" ] }, { @@ -285,7 +285,7 @@ " user_phrase = input()\n", " if len(user_phrase) == 0:\n", " break\n", - " inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n", + " inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n", " while True:\n", " outputs = model.generate(\n", " inputs,\n", diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 9ab9469..54840b1 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -36,7 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install -q petals datasets wandb" + "%pip install -q petals datasets wandb scikit-learn" ] }, { @@ -390,16 +390,14 @@ "metadata": {}, "outputs": [], "source": [ - "model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n", - "\n", "cls_model = BloomBasedClassifier(\n", - " model,\n", + " DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n", " intermediate_size=INTERMEDIATE_SIZE,\n", " adapter_layer_position=ADAPTER_LAYER_POSITION,\n", " head_layer_position=HEAD_LAYER_POSITION,\n", - ")\n", + ").to(DEVICE)\n", "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", - "cls_criterion = nn.CrossEntropyCriterion()\n", + "cls_criterion = nn.CrossEntropyLoss()\n", "\n", "lr_scheduler = get_scheduler(\n", " name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", @@ -442,7 +440,7 @@ "\n", " cls_model.train()\n", " with torch.no_grad():\n", - " embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n", + " embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n", " outputs = cls_model(embeddings_output)\n", " loss = cls_criterion(outputs, batch[\"labels\"])\n", " loss.backward()\n", @@ -453,7 +451,7 @@ "\n", " wandb.log({\"Train Loss\": loss})\n", "\n", - " accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n", + " accuracy = eval_metrics(cls_model, valid_dataloader, device=DEVICE)\n", " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)" ] } From b8a6788490b3ef36935d9581d4a8d6523c6aae9b Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 1 Feb 2023 21:32:27 +0300 Subject: [PATCH 050/168] Fix examples/sst, add cls_model embeddings (#248) --- examples/prompt-tuning-sst2.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 54840b1..c5dac6a 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -288,7 +288,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "1bbf014f", "metadata": {}, @@ -324,6 +323,7 @@ " self.adapter_layer_position = adapter_layer_position\n", " self.head_layer_position = head_layer_position\n", " \n", + " self.word_embeddings = model.transformer.word_embeddings\n", " self.adapter = nn.Sequential(\n", " nn.Linear(self.hidden_size, self.intermediate_size),\n", " nn.Linear(self.intermediate_size, self.hidden_size),\n", @@ -440,7 +440,7 @@ "\n", " cls_model.train()\n", " with torch.no_grad():\n", - " embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n", + " embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n", " outputs = cls_model(embeddings_output)\n", " loss = cls_criterion(outputs, batch[\"labels\"])\n", " loss.backward()\n", @@ -458,7 +458,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.9 64-bit", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -472,7 +472,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]" + "version": "3.8.8" }, "vscode": { "interpreter": { From 3c523ab0d2e1f16381724f6e6f288cc9158ae086 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 3 Feb 2023 01:04:19 +0600 Subject: [PATCH 051/168] Fix TP crashing when hypo_ids are used (#249) --- src/petals/server/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 81f3a33..cd8dce4 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -99,7 +99,7 @@ class TransformerBackend(ModuleBackend): """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids""" if not is_dummy(hypo_ids): for cache_tensor in cache_tensors: - cache_tensor[...] = cache_tensor[hypo_ids] # in-place reorder cache by hypo ids + cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)] # in-place reorder cache by hypo ids def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]: """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past""" From 9954cb84fed2b4597f0c9d99b38472c0060b49bf Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 6 Feb 2023 01:22:18 +0300 Subject: [PATCH 052/168] Add `allowed_servers`, `max_retries` options to the client, improve logs (#235) --- src/petals/client/inference_session.py | 5 ++-- src/petals/client/remote_model.py | 12 ++++++-- src/petals/client/remote_sequential.py | 11 ++++++- src/petals/client/routing/sequence_manager.py | 29 +++++++++++++++++-- src/petals/client/sequential_autograd.py | 10 ++++--- src/petals/server/block_selection.py | 3 +- 6 files changed, 57 insertions(+), 13 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 95e5ff5..902dd0f 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -307,10 +307,11 @@ class InferenceSession: except Exception as e: if span is not None: self._sequence_manager.on_request_failure(span.peer_id) + if attempt_no + 1 == self._sequence_manager.max_retries: + raise delay = self._sequence_manager.get_retry_delay(attempt_no) logger.warning( - f"Caught exception when running inference from block {block_idx} " - f"(retry in {delay:.0f} sec): {repr(e)}" + f"Caught exception when running inference via {span} (retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) time.sleep(delay) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index af8a20c..6f8ebf1 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -1,6 +1,6 @@ import os from contextlib import contextmanager -from typing import List, Optional, Union +from typing import Collection, List, Optional, Union import hivemind import torch @@ -35,6 +35,10 @@ class DistributedBloomConfig(BloomConfig): daemon_startup_timeout: int = 30 dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models request_timeout: int = 30 # a number of seconds for waiting result from each node + max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) + allowed_servers: Optional[ + Collection[Union[str, hivemind.PeerID]] + ] = None # if defined, send requests only to these servers pre_seq_len: int = 0 # a number of tokens for prompt tuning. tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters'] @@ -112,7 +116,11 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel): ) ) assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance" - self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout) + self.h = RemoteSequential( + config, + dht, + config.dht_prefix, + ) # Forbid accumulate grads for embeddings and layernorm self.set_requires_grad(False) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 2dc3c5b..6b9841e 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -41,7 +41,16 @@ class RemoteSequential(nn.Module): block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)) if sequence_manager is None: logger.debug(f"Creating new sequence manager for block uids: {block_uids}") - self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, start=True, **kwargs) + self.sequence_manager = RemoteSequenceManager( + dht, + block_uids, + self.p2p, + request_timeout=config.request_timeout, + max_retries=config.max_retries, + allowed_servers=config.allowed_servers, + start=True, + **kwargs, + ) self.is_subsequence = False else: logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules") diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 441b9d4..90f8c47 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -6,7 +6,7 @@ import logging import random import threading import time -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Collection, Dict, List, Optional, Sequence, Union from weakref import WeakMethod import numpy as np @@ -40,9 +40,10 @@ class RemoteSequenceManager: :param update_period: by default, refresh DHT information once in this many seconds :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1) + :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time - :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds + :param allowed_servers: if defined, send requests only to these servers :param start: start the background thread (see the note below). If false, you will need to start it manually. :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid running redundant sequence managers for the same set of layers. @@ -56,21 +57,30 @@ class RemoteSequenceManager: p2p: P2P, update_period: float = 30, request_timeout: float = 30, + max_retries: Optional[int] = None, min_backoff: float = 1, ban_timeout: float = 15, sequence_info: Optional[RemoteSequenceInfo] = None, rpc_info: Optional[dict] = None, + allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None, banned_peers: Optional[Blacklist] = None, *, # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below) start: bool, ): assert len(block_uids) > 0, "Sequences must contain at least one block" self.dht, self.p2p = dht, p2p - self.request_timeout, self.ban_timeout, self.min_backoff = request_timeout, ban_timeout, min_backoff + self.request_timeout, self.max_retries = request_timeout, max_retries + self.ban_timeout, self.min_backoff = ban_timeout, min_backoff self.lock_changes = threading.Lock() self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update)) self.policy = NoSpendingPolicy() self._rpc_info = rpc_info + + if allowed_servers is not None: + allowed_servers = { + PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers + } + self.allowed_servers = allowed_servers self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers if sequence_info is None: @@ -148,6 +158,7 @@ class RemoteSequenceManager: min_backoff=self.min_backoff, sequence_info=self.sequence_info[ix], rpc_info=self._rpc_info, + allowed_servers=self.allowed_servers, banned_peers=self.banned_peers, start=True, ) @@ -169,6 +180,16 @@ class RemoteSequenceManager: for block_info in new_block_infos: if not block_info: continue + + # Apply whitelist, if defined + if self.allowed_servers is not None: + block_info.servers = { + peer_id: server_info + for peer_id, server_info in block_info.servers.items() + if peer_id in self.allowed_servers + } + + # Remove temporarily banned peers, unless there are no peers left valid_servers = { peer_id: server_info for peer_id, server_info in block_info.servers.items() @@ -260,6 +281,8 @@ class RemoteSequenceManager: except Exception as e: if peer_id is not None and not isinstance(e, P2PHandlerError): self.on_request_failure(peer_id) + if attempt_no + 1 == self.max_retries: + raise delay = self.get_retry_delay(attempt_no) logger.warning( f"Caught exception when gathering information from peer {peer_id} " diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index debcb7b..75d087b 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -95,10 +95,11 @@ async def sequential_forward( except Exception as e: if span is not None: sequence_manager.on_request_failure(span.peer_id) + if attempt_no + 1 == sequence_manager.max_retries: + raise delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( - f"Caught exception when running forward from block {block_idx} " - f"(retry in {delay:.0f} sec): {repr(e)}" + f"Caught exception when running forward via {span} (retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) await asyncio.sleep(delay) @@ -172,10 +173,11 @@ async def sequential_backward( except Exception as e: if span is not None: sequence_manager.on_request_failure(span.peer_id) + if attempt_no + 1 == sequence_manager.max_retries: + raise delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( - f"Caught exception when running backward between blocks {span.start}-{span.end} " - f"(retry in {delay:.0f} sec): {repr(e)}" + f"Caught exception when running backward via {span} (retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) await asyncio.sleep(delay) diff --git a/src/petals/server/block_selection.py b/src/petals/server/block_selection.py index 33161b2..1aa39da 100644 --- a/src/petals/server/block_selection.py +++ b/src/petals/server/block_selection.py @@ -16,6 +16,7 @@ class Span: start: int end: int throughput: float + state: ServerState @property def length(self): @@ -43,7 +44,7 @@ def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[ spans[peer_id].start = min(spans[peer_id].start, block) spans[peer_id].end = max(spans[peer_id].start, block + 1) else: - spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput) + spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state) throughputs[block] += server.throughput From 4091db10bf696c039d3477fab7082c015f2db284 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 7 Feb 2023 00:56:58 +0400 Subject: [PATCH 053/168] Lower payload size threshold for stream handlers (#251) Hotfix: we add "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space. --- src/petals/client/remote_forward_backward.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index 542ad9c..df97db1 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -108,7 +108,8 @@ async def run_remote_forward( # call RPC on remote server size = sum(t.element_size() * t.nelement() for t in inputs) - forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary + forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary + # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) @@ -150,6 +151,7 @@ async def run_remote_backward( ) size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) - backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary + backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary + # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) return deserialized_grad_inputs From 2a5070aa1ab6383026c0d1e904f36647d4764894 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 7 Feb 2023 01:52:36 +0400 Subject: [PATCH 054/168] Improve reachability logs (#253) --- src/petals/server/reachability.py | 2 +- src/petals/server/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py index 7ead055..58caa93 100644 --- a/src/petals/server/reachability.py +++ b/src/petals/server/reachability.py @@ -71,7 +71,7 @@ def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwar if requests >= max_peers: break - logger.info(f"Direct reachability: {successes}/{requests}") + logger.debug(f"Direct reachability: {successes}/{requests}") return (successes / requests) >= threshold if requests > 0 else None finally: await target_dht.shutdown() diff --git a/src/petals/server/server.py b/src/petals/server/server.py index a25fce6..a84d229 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -122,7 +122,7 @@ class Server: if dht_client_mode is None: is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs) dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer - logger.info(f"This server will run DHT in {'client' if dht_client_mode else 'full peer'} mode") + logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}") self.dht = DHT( initial_peers=initial_peers, start=True, From 42594e517349b668c4fb4a082ec1324a583340ee Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 17 Feb 2023 07:54:02 +0400 Subject: [PATCH 055/168] Link FAQ in readme (#260) --- README.md | 38 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 0341fe5..ef34a08 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ for input_ids, labels in data_loader: ### Connect your GPU and increase Petals capacity -Run this in an [Anaconda](https://www.anaconda.com) env: +Run this in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+): ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia @@ -43,16 +43,18 @@ pip install -U petals python -m petals.cli.run_server bigscience/bloom-petals ``` -Or use our [Docker](https://www.docker.com) image: +Or use our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windows with [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)): ```bash -sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --rm \ - learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals +sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ + learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals --port 31330 ``` +📚 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to configure the server to use multiple GPUs, address common issues, etc. + You can also host [BLOOMZ](https://huggingface.co/bigscience/bloomz), a version of BLOOM fine-tuned to follow human instructions in the zero-shot regime — just replace `bloom-petals` with `bloomz-petals`. -🔒 This does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). +🔒 Hosting a server does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). 💬 If you have any issues or feedback, let us know on [our Discord server](https://discord.gg/D9MwApKgWa)! @@ -64,16 +66,18 @@ Basic tutorials: - Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) - Prompt-tune BLOOM for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) -Example apps built with Petals: - -- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP endpoint): [source code](https://github.com/borzunov/chat.petals.ml) - Useful tools and advanced guides: +- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP endpoint): [source code](https://github.com/borzunov/chat.petals.ml) - [Monitor](http://health.petals.ml) for the public swarm: [source code](https://github.com/borzunov/health.petals.ml) - Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) +Learning more: + +- Frequently asked questions: [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions) +- In-depth system description: [paper](https://arxiv.org/abs/2209.01188) + 📋 If you build an app running BLOOM with Petals, make sure it follows the BLOOM's [terms of use](https://huggingface.co/bigscience/bloom). ## How does it work? @@ -87,23 +91,11 @@ Useful tools and advanced guides:

+ 📚  See FAQ +            📜  Read paper

-## FAQ - -1. **What's the motivation for people to host model layers in the public swarm?** - - People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may be also motivated to "give back" to the community helping them to run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded). - - Since it may be not enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who earned these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards. - -2. **Why is the platform named "Petals"?** - - "Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model — [BLOOM](https://huggingface.co/bigscience/bloom). - - While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in future. - ## Installation Here's how to install Petals with conda: From 38b071135bc09d937679dc04d0cbd5b24cd17bf2 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 19 Feb 2023 04:34:47 +0400 Subject: [PATCH 056/168] Show visible maddrs for public swarm too (#263) --- src/petals/server/server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index a84d229..3a5eefd 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -136,9 +136,10 @@ class Server: visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()] if initial_peers == PUBLIC_INITIAL_PEERS: - logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}") + logger.info("Connecting to the public swarm") else: - logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") + logger.info(f"Connecting to a private swarm, initial peers: {initial_peers}") + logger.info(f"Running a server on {visible_maddrs_str}") self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS if device is None: From 55e7dc07a0bafa75286b6e728e9159913b3b3227 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 19 Feb 2023 05:07:21 +0400 Subject: [PATCH 057/168] Limit max delay between retries to 15 min (#264) --- src/petals/client/routing/sequence_manager.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 90f8c47..a299bc7 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -40,6 +40,7 @@ class RemoteSequenceManager: :param update_period: by default, refresh DHT information once in this many seconds :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1) + :param max_backoff: limit maximal sleep time between retries to this value :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time @@ -59,6 +60,7 @@ class RemoteSequenceManager: request_timeout: float = 30, max_retries: Optional[int] = None, min_backoff: float = 1, + max_backoff: float = 15 * 60, ban_timeout: float = 15, sequence_info: Optional[RemoteSequenceInfo] = None, rpc_info: Optional[dict] = None, @@ -70,7 +72,7 @@ class RemoteSequenceManager: assert len(block_uids) > 0, "Sequences must contain at least one block" self.dht, self.p2p = dht, p2p self.request_timeout, self.max_retries = request_timeout, max_retries - self.ban_timeout, self.min_backoff = ban_timeout, min_backoff + self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff self.lock_changes = threading.Lock() self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update)) self.policy = NoSpendingPolicy() @@ -156,6 +158,7 @@ class RemoteSequenceManager: request_timeout=self.request_timeout, ban_timeout=self.ban_timeout, min_backoff=self.min_backoff, + max_backoff=self.max_backoff, sequence_info=self.sequence_info[ix], rpc_info=self._rpc_info, allowed_servers=self.allowed_servers, @@ -296,7 +299,7 @@ class RemoteSequenceManager: def get_retry_delay(self, attempt_no: int) -> float: if attempt_no == 0: return 0 - return self.min_backoff * 2 ** (attempt_no - 1) + return min(self.min_backoff * 2 ** (attempt_no - 1), self.max_backoff) def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]: """ From fee19e9b9b434dcd79c0d271898aaf3b9d724939 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 19 Feb 2023 05:46:17 +0400 Subject: [PATCH 058/168] Use get_logger(__name__) instead of get_logger(__file__) (#265) --- src/petals/bloom/from_pretrained.py | 2 +- src/petals/bloom/modeling_utils.py | 2 +- src/petals/cli/convert_model.py | 2 +- src/petals/cli/inference_one_block.py | 2 +- src/petals/cli/run_server.py | 2 +- src/petals/client/inference_session.py | 2 +- src/petals/client/remote_generation.py | 2 +- src/petals/client/remote_model.py | 2 +- src/petals/client/remote_sequential.py | 2 +- src/petals/client/routing/sequence_info.py | 2 +- src/petals/client/routing/sequence_manager.py | 2 +- src/petals/client/sequential_autograd.py | 2 +- src/petals/dht_utils.py | 2 +- src/petals/server/backend.py | 2 +- src/petals/server/block_selection.py | 2 +- src/petals/server/handler.py | 2 +- src/petals/server/memory_cache.py | 2 +- src/petals/server/server.py | 2 +- src/petals/server/task_pool.py | 2 +- src/petals/server/throughput.py | 2 +- src/petals/utils/convert_block.py | 2 +- src/petals/utils/disk_cache.py | 2 +- src/petals/utils/version.py | 2 +- tests/test_full_model.py | 2 +- tests/test_remote_sequential.py | 2 +- tests/test_sequence_manager.py | 2 +- 26 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index fa31602..f8e41a7 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -22,7 +22,7 @@ from petals.bloom.block import WrappedBloomBlock from petals.server.block_utils import get_block_size from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for -logger = get_logger(__file__) +logger = get_logger(__name__) CLIENT_BRANCH = "main" BLOCK_BRANCH_PREFIX = "block_" diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/bloom/modeling_utils.py index 4e2899c..cb069b8 100644 --- a/src/petals/bloom/modeling_utils.py +++ b/src/petals/bloom/modeling_utils.py @@ -13,7 +13,7 @@ from hivemind import get_logger from torch import nn from transformers import BloomConfig -logger = get_logger(__file__) +logger = get_logger(__name__) class LMHead(nn.Module): diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index 289c764..6f7499d 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -13,7 +13,7 @@ from transformers.models.bloom.modeling_bloom import BloomModel from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH from petals.client import DistributedBloomConfig -logger = get_logger(__file__) +logger = get_logger(__name__) DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py index 9f7c5b4..01ba1ef 100644 --- a/src/petals/cli/inference_one_block.py +++ b/src/petals/cli/inference_one_block.py @@ -8,7 +8,7 @@ from transformers.models.bloom.modeling_bloom import build_alibi_tensor from petals.bloom.block import BloomBlock -logger = get_logger(__file__) +logger = get_logger(__name__) logger.warning("inference_one_block will soon be deprecated in favour of tests!") diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 135720d..5fb700d 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -10,7 +10,7 @@ from petals.constants import PUBLIC_INITIAL_PEERS from petals.server.server import Server from petals.utils.version import validate_version -logger = get_logger(__file__) +logger = get_logger(__name__) def main(): diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 902dd0f..24a188a 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -25,7 +25,7 @@ from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, R from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, is_dummy -logger = get_logger(__file__) +logger = get_logger(__name__) class _ServerInferenceSession: diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index 4ea0c9b..bcf59ab 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -15,7 +15,7 @@ from petals.utils.generation_algorithms import ( ) from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint -logger = get_logger(__file__) +logger = get_logger(__name__) class RemoteGenerationMixin: diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 6f8ebf1..e2c0258 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -21,7 +21,7 @@ from petals.client.remote_sequential import RemoteSequential from petals.constants import PUBLIC_INITIAL_PEERS from petals.utils.misc import DUMMY -logger = get_logger(__file__) +logger = get_logger(__name__) class DistributedBloomConfig(BloomConfig): diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 6b9841e..31a33af 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -14,7 +14,7 @@ from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction from petals.data_structures import UID_DELIMITER from petals.utils.misc import DUMMY -logger = get_logger(__file__) +logger = get_logger(__name__) class RemoteSequential(nn.Module): diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py index e69cd35..de7eb37 100644 --- a/src/petals/client/routing/sequence_info.py +++ b/src/petals/client/routing/sequence_info.py @@ -6,7 +6,7 @@ from hivemind import get_logger from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState -logger = get_logger(__file__) +logger = get_logger(__name__) T = TypeVar("T") diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index a299bc7..1ca58cf 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -23,7 +23,7 @@ from petals.client.routing.spending_policy import NoSpendingPolicy from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState from petals.server.handler import TransformerConnectionHandler -logger = get_logger(__file__) +logger = get_logger(__name__) class RemoteSequenceManager: diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 75d087b..b846dfc 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -18,7 +18,7 @@ from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, is_dummy -logger = get_logger(__file__) +logger = get_logger(__name__) MAX_TOKENS_IN_BATCH = 1024 diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index 09aa27a..3542e40 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -15,7 +15,7 @@ from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger import petals.client from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState -logger = get_logger(__file__) +logger = get_logger(__name__) def declare_active_modules( diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index cd8dce4..dc9cebb 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -20,7 +20,7 @@ from petals.server.memory_cache import Handle, MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import is_dummy -logger = get_logger(__file__) +logger = get_logger(__name__) class TransformerBackend(ModuleBackend): diff --git a/src/petals/server/block_selection.py b/src/petals/server/block_selection.py index 1aa39da..cc050d4 100644 --- a/src/petals/server/block_selection.py +++ b/src/petals/server/block_selection.py @@ -8,7 +8,7 @@ from petals.data_structures import RemoteModuleInfo, ServerState __all__ = ["choose_best_blocks", "should_choose_other_blocks"] -logger = get_logger(__file__) +logger = get_logger(__name__) @dataclass diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index b1c36ed..79376f8 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -32,7 +32,7 @@ from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase from petals.utils.misc import DUMMY, is_dummy -logger = get_logger(__file__) +logger = get_logger(__name__) CACHE_TOKENS_AVAILABLE = "cache_tokens_available" diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py index 0e39cf5..7ea981f 100644 --- a/src/petals/server/memory_cache.py +++ b/src/petals/server/memory_cache.py @@ -18,7 +18,7 @@ from hivemind.utils import TensorDescriptor, get_logger from petals.utils.asyncio import shield_and_wait -logger = get_logger(__file__) +logger = get_logger(__name__) Handle = int diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 3a5eefd..29e9d6b 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -31,7 +31,7 @@ from petals.server.throughput import get_dtype_name, get_host_throughput from petals.utils.convert_block import check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR -logger = get_logger(__file__) +logger = get_logger(__name__) class Server: diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py index 330679c..e027d52 100644 --- a/src/petals/server/task_pool.py +++ b/src/petals/server/task_pool.py @@ -12,7 +12,7 @@ from hivemind import get_logger from hivemind.moe.server.task_pool import TaskPoolBase from hivemind.utils.mpfuture import ALL_STATES, MPFuture -logger = get_logger(__file__) +logger = get_logger(__name__) @dataclass(order=True, frozen=True) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 8b6dc9c..ac43759 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -16,7 +16,7 @@ from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_block import convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR -logger = get_logger(__file__) +logger = get_logger(__name__) try: import speedtest diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 0afe641..4938289 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -15,7 +15,7 @@ from transformers.models.bloom.modeling_bloom import BloomAttention from petals.bloom.block import WrappedBloomBlock use_hivemind_log_handler("in_root_logger") -logger = get_logger(__file__) +logger = get_logger(__name__) def convert_block( diff --git a/src/petals/utils/disk_cache.py b/src/petals/utils/disk_cache.py index eb23477..3217e34 100644 --- a/src/petals/utils/disk_cache.py +++ b/src/petals/utils/disk_cache.py @@ -8,7 +8,7 @@ from typing import Optional import huggingface_hub from hivemind.utils.logging import get_logger -logger = get_logger(__file__) +logger = get_logger(__name__) DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals")) diff --git a/src/petals/utils/version.py b/src/petals/utils/version.py index b992c27..f4a5be1 100644 --- a/src/petals/utils/version.py +++ b/src/petals/utils/version.py @@ -4,7 +4,7 @@ from packaging.version import parse import petals -logger = get_logger(__file__) +logger = get_logger(__name__) def validate_version(): diff --git a/tests/test_full_model.py b/tests/test_full_model.py index d2b272f..1c48c87 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -8,7 +8,7 @@ from transformers.models.bloom import BloomForCausalLM from petals.client.remote_model import DistributedBloomForCausalLM -logger = get_logger(__file__) +logger = get_logger(__name__) @pytest.mark.forked diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index a8e585f..7f49a6e 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -10,7 +10,7 @@ from petals.client import RemoteSequenceManager, RemoteSequential from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER -logger = get_logger(__file__) +logger = get_logger(__name__) @pytest.mark.forked diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 29562c3..7c175a8 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -10,7 +10,7 @@ from petals.client import RemoteSequenceManager, RemoteSequential from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER -logger = get_logger(__file__) +logger = get_logger(__name__) @pytest.mark.forked From a2e7f27a5a49939236311968b0749f4220d1ae49 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 19 Feb 2023 07:00:16 +0400 Subject: [PATCH 059/168] Improve "connect your GPU" message (#266) --- src/petals/client/routing/sequence_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 1ca58cf..6899fd1 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -365,8 +365,8 @@ def maybe_log_traceback(exc: Exception): class MissingBlocksError(RuntimeError): def __init__(self, block_indices: Union[int, Sequence[int]]): super().__init__( - f"No servers holding blocks {block_indices} are online.\n" - f"You can check the public swarm's state at http://health.petals.ml\n\n" - f"If there are not enough servers, please consider connecting your own GPU:\n" - f"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity" + f"No servers holding blocks {block_indices} are online. " + f"You can check the public swarm's state at http://health.petals.ml " + f"If there are not enough servers, please connect your GPU: " + f"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity " ) From fd9400b392d57d6ef16253e74aab0c60c82227a5 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 21 Feb 2023 06:11:53 +0400 Subject: [PATCH 060/168] Fix use_chunked_forward="auto" on non-x86_64 machines (#267) Import of cpufeature may crash on non-x86_64 machines, so this PR makes the client import it only if necessary. --- src/petals/bloom/modeling_utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/bloom/modeling_utils.py index cb069b8..eddbb9d 100644 --- a/src/petals/bloom/modeling_utils.py +++ b/src/petals/bloom/modeling_utils.py @@ -4,11 +4,12 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e See commit history for authorship. """ +import platform + import psutil import torch import torch.nn.functional as F import torch.utils.checkpoint -from cpufeature import CPUFeature from hivemind import get_logger from torch import nn from transformers import BloomConfig @@ -29,9 +30,15 @@ class LMHead(nn.Module): self.use_chunked_forward = config.use_chunked_forward if self.use_chunked_forward == "auto": - # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward(). - # Otherwise, it's ~8x slower. - self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]) + if platform.machine() == "x86_64": + # Import of cpufeature may crash on non-x86_64 machines + from cpufeature import CPUFeature + + # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward(). + # Otherwise, it's ~8x slower. + self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]) + else: + self.use_chunked_forward = True self.chunked_forward_step = config.chunked_forward_step self._bf16_warning_shown = False From fb2583b682337f2becd65c4f140a464a4469034f Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 27 Feb 2023 12:28:01 +0300 Subject: [PATCH 061/168] Use inference mode in _MergedInferenceStep (#275) --- src/petals/server/backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index dc9cebb..4464e7c 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -159,6 +159,7 @@ class _MergedInferenceStep: def __init__(self, backends: Dict[ExpertUID, TransformerBackend]): self.backends = backends + @torch.inference_mode() def __call__( self, hidden_states: torch.Tensor, From aae1f4f368504bc39c2e46c61cee1491432721eb Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 27 Feb 2023 16:43:06 +0400 Subject: [PATCH 062/168] Increase default request_timeout (#276) This PR increases `request_timeout`, since the previous default of 30 sec is not enough for many use cases. Previously, we kept the request timeout low since we assumed that the server could freeze on dial if the target peer is behind a firewall. However, apparently, it won't freeze because libp2p has its own [dial timeout](https://github.com/libp2p/go-libp2p/blob/v0.26.0/core/network/context.go#L11). --- src/petals/client/remote_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index e2c0258..2ec127f 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -34,7 +34,7 @@ class DistributedBloomConfig(BloomConfig): dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) daemon_startup_timeout: int = 30 dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models - request_timeout: int = 30 # a number of seconds for waiting result from each node + request_timeout: int = 3 * 60 # a number of seconds for waiting result from each node max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) allowed_servers: Optional[ Collection[Union[str, hivemind.PeerID]] From c519bffc59f957cac969504306ad313340e85b9c Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 1 Mar 2023 13:04:21 +0400 Subject: [PATCH 063/168] Bump version to 1.1.3 (#278) --- src/petals/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index ec44076..1a34085 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,6 +1,6 @@ from petals.client import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.1.2" +__version__ = "1.1.3" _initialize_logs() From 793726b041d5d4b9622ef70e84fdf93ab6cbdc3d Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 12 Mar 2023 22:49:04 +0100 Subject: [PATCH 064/168] Speed up loading blocks using init with meta weights (#285) * Init WrappedBloomBlock with meta weights --------- Co-authored-by: Alexander Borzunov --- pyproject.toml | 3 +- src/petals/bloom/from_pretrained.py | 26 +++++++++------ src/petals/server/block_utils.py | 2 +- tests/test_aux_functions.py | 4 +-- tests/test_block_exact_match.py | 51 +++++++++++++++++++++++++++-- tests/test_chained_calls.py | 2 +- tests/test_full_model.py | 2 +- tests/test_remote_sequential.py | 4 +-- tests/test_sequence_manager.py | 2 +- tests/test_server_stats.py | 2 +- tests/test_tensor_parallel.py | 2 +- 11 files changed, 77 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e6f5197..cfc991c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,4 +14,5 @@ profile = "black" line_length = 120 combine_as_imports = true combine_star = true -known_local_folder = ["tests", "cli"] \ No newline at end of file +known_local_folder = ["tests", "cli"] +known_first_party = ["test_utils"] diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index f8e41a7..9f1d12b 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -13,6 +13,8 @@ import time from typing import Optional, OrderedDict, Union import torch +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device from hivemind.utils.logging import get_logger from transformers.modeling_utils import WEIGHTS_NAME from transformers.models.bloom.configuration_bloom import BloomConfig @@ -38,13 +40,16 @@ def load_pretrained_block( max_disk_space: Optional[int] = None, ) -> WrappedBloomBlock: """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it.""" + assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" if config is None: config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR - block = WrappedBloomBlock(config) + with init_empty_weights(): + block = WrappedBloomBlock(config) + state_dict = _load_state_dict( converted_model_name_or_path, block_index, @@ -54,16 +59,17 @@ def load_pretrained_block( max_disk_space=max_disk_space, ) - if torch_dtype == "auto": - with torch.no_grad(): - for name, param in block.named_parameters(): - assert name in state_dict, f"{name} not in state dict" - param.data = param.data.to(state_dict[name].dtype) - else: - assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - block = block.to(dtype=torch_dtype) - + # dummy load, check that keys match report = block.load_state_dict(state_dict, strict=True) + assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" + + for param_name, _ in block.named_parameters(): + assert param_name in state_dict, f"{param_name} not in state dict" + param = state_dict[param_name] + if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + param = param.to(torch_dtype) + set_module_tensor_to_device(block, param_name, "cpu", value=param) + logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}") return block diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index eca7143..fd39ad6 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -30,7 +30,7 @@ def get_block_size( dtype is not None and load_in_8bit is not None ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations' - with init_empty_weights(): + with init_empty_weights(include_buffers=True): block = WrappedBloomBlock(config) n_params = sum(param.numel() for param in block.parameters()) diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index 1986f0a..6909ccf 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -1,9 +1,9 @@ import pytest import torch -from test_utils import MODEL_NAME from petals.client import DistributedBloomConfig -from petals.server.throughput import measure_compute_rps, measure_network_rps +from petals.server.throughput import measure_compute_rps +from test_utils import MODEL_NAME @pytest.mark.forked diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index 664f255..d2fbdde 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -1,15 +1,18 @@ import random +from typing import Union import hivemind import pytest import torch -from test_utils import * +from transformers.models.bloom.configuration_bloom import BloomConfig -from petals.bloom.from_pretrained import load_pretrained_block +from petals.bloom.block import WrappedBloomBlock +from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block from petals.client import DistributedBloomConfig from petals.client.remote_sequential import RemoteTransformerBlock from petals.data_structures import UID_DELIMITER from petals.dht_utils import get_remote_module +from test_utils import * @pytest.mark.forked @@ -41,3 +44,47 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward) assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference) + + +def _old_load_pretrained_block( + converted_model_name_or_path: str, + block_index: int, + torch_dtype: Union[torch.dtype, str] = "auto", +) -> WrappedBloomBlock: + """Load the BLOOM block by directly initializing the weights. + This test is used to check consistency with the previous implementation and can be removed in the future.""" + config = BloomConfig.from_pretrained(converted_model_name_or_path) + + block = WrappedBloomBlock(config) + state_dict = _load_state_dict( + converted_model_name_or_path, + block_index, + config, + cache_dir=None, + ) + + if torch_dtype == "auto": + with torch.no_grad(): + for name, param in block.named_parameters(): + assert name in state_dict, f"{name} not in state dict" + param.data = param.data.to(state_dict[name].dtype) + else: + assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" + block = block.to(dtype=torch_dtype) + + block.load_state_dict(state_dict, strict=True) + return block + + +@pytest.mark.forked +def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8): + config = DistributedBloomConfig.from_pretrained(MODEL_NAME) + torch.random.manual_seed(0) + inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype) + + block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype) + ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype) + + outputs = block.forward(inputs)[0] + outputs_ref = ref_block.forward(inputs)[0] + assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward) diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 261361b..9a619b7 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -7,12 +7,12 @@ import hivemind import pytest import torch -from test_utils import * from petals.bloom.from_pretrained import load_pretrained_block from petals.client import DistributedBloomConfig from petals.client.remote_sequential import RemoteSequential from petals.dht_utils import get_remote_sequence +from test_utils import * @pytest.mark.forked diff --git a/tests/test_full_model.py b/tests/test_full_model.py index 1c48c87..cef002e 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -2,11 +2,11 @@ import pytest import torch import transformers from hivemind import get_logger -from test_utils import * from transformers.generation import BeamSearchScorer from transformers.models.bloom import BloomForCausalLM from petals.client.remote_model import DistributedBloomForCausalLM +from test_utils import * logger = get_logger(__name__) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 7f49a6e..18b41a1 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -1,14 +1,14 @@ import pytest import torch import torch.nn.functional as F -from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler +from hivemind import DHT, BatchTensorDescriptor, get_logger from hivemind.proto import runtime_pb2 -from test_utils import * from petals.bloom.from_pretrained import load_pretrained_block from petals.client import RemoteSequenceManager, RemoteSequential from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER +from test_utils import * logger = get_logger(__name__) diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 7c175a8..9185ef1 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -4,11 +4,11 @@ import time import pytest import torch from hivemind import DHT, get_logger -from test_utils import * from petals.client import RemoteSequenceManager, RemoteSequential from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER +from test_utils import * logger = get_logger(__name__) diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index 0f2b3f0..54d6d33 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -3,12 +3,12 @@ import time import hivemind import pytest import torch -from test_utils import * from petals.client import DistributedBloomConfig from petals.data_structures import UID_DELIMITER from petals.dht_utils import get_remote_sequence from petals.server.handler import CACHE_TOKENS_AVAILABLE +from test_utils import * @pytest.mark.forked diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 9d3ba59..84fcab4 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -5,9 +5,9 @@ import torch import transformers from tensor_parallel import TensorParallel from tensor_parallel.slicing_configs import get_bloom_config -from test_utils import MODEL_NAME from petals.bloom.from_pretrained import load_pretrained_block +from test_utils import MODEL_NAME @pytest.mark.forked From 8dab37c1a90e1aeff2bd98d759509d333eca4ac6 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 13 Mar 2023 05:55:27 +0400 Subject: [PATCH 065/168] Add benchmarks to readme (#284) --- README.md | 149 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 97 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index ef34a08..182be65 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@


Run 100B+ language models at home, BitTorrent-style.
- Fine-tuning and inference up to 10x faster than offloading

+ Fine-tuning and inference up to 10x faster than offloading


@@ -83,8 +83,8 @@ Learning more: ## How does it work? - Petals runs large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. -- Inference runs at ≈ 1 sec per step (token) — 10x faster than possible with offloading, enough for chatbots and other interactive apps. Parallel inference reaches hundreds of tokens/sec. -- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. You get the comforts of an API with the flexibility of PyTorch. +- Single-batch inference runs at ≈ 1 sec per step (token) — [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](http://chat.petals.ml) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. +- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.

@@ -98,61 +98,106 @@ Learning more: ## Installation -Here's how to install Petals with conda: +Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/distribution) on Linux: ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install -U petals ``` -This script uses Anaconda to install CUDA-enabled PyTorch. -If you don't have anaconda, you can get it from [here](https://www.anaconda.com/products/distribution). -If you don't want anaconda, you can install PyTorch [any other way](https://pytorch.org/get-started/locally/). -If you want to run models with 8-bit weights, please install **PyTorch with CUDA 11** or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes). - -__System requirements:__ Petals only supports Linux for now. If you don't have a Linux machine, consider running Petals in Docker (see our [image](https://hub.docker.com/r/learningathome/petals)) or, in case of Windows, in WSL2 ([read more](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)). CPU is enough to run a client, but you probably need a GPU to run a server efficiently. - -## 🛠️ Development - -Petals uses pytest with a few plugins. To install them, run: - -```bash -conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -git clone https://github.com/bigscience-workshop/petals.git && cd petals -pip install -e .[dev] -``` - -To run minimalistic tests, you need to make a local swarm with a small model and some servers. You may find more information about how local swarms work and how to run them in [this tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm). - -```bash -export MODEL_NAME=bloom-testing/test-bloomd-560m-main - -python -m petals.cli.run_server $MODEL_NAME --block_indices 0:12 \ - --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --new_swarm &> server1.log & -sleep 5 # wait for the first server to initialize DHT - -python -m petals.cli.run_server $MODEL_NAME --block_indices 12:24 \ - --initial_peers SEE_THE_OUTPUT_OF_THE_1ST_PEER &> server2.log & - -tail -f server1.log server2.log # view logs for both servers -``` - -Then launch pytest: - -```bash -export MODEL_NAME=bloom-testing/test-bloomd-560m-main REF_NAME=bigscience/bloom-560m -export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g -PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v -``` - -After you're done, you can terminate the servers and ensure that no zombie processes are left with `pkill -f petals.cli.run_server && pkill -f p2p`. - -The automated tests use a more complex server configuration that can be found [here](https://github.com/bigscience-workshop/petals/blob/main/.github/workflows/run-tests.yaml). - -### Code style - -We use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests. -Before committing your code, simply run `black . && isort .` and you will be fine. +If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes). + +See the instructions for macOS and Windows, the full requirements, and troubleshooting advice in our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-client). + +## ⏱️ Benchmarks + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NetworkSingle-batch inference
(steps/s)
Parallel forward
(tokens/s)
BandwidthRound-trip
latency
Sequence lengthBatch size
1282048164
Offloading, max. possible speed on 1x A100 1
256 Gbit/s0.180.182.7170.3
128 Gbit/s0.090.092.4152.8
Petals on 14 heterogeneous servers across Europe and North America 2
Real world0.830.7932.6179.4
Petals on 3 servers, with one A100 each 3
1 Gbit/s< 5 ms1.711.5470.0253.6
100 Mbit/s< 5 ms1.661.4956.4182.0
100 Mbit/s100 ms1.231.1119.7112.2
+ +1 **An upper bound for offloading performance.** We base our offloading numbers on the best possible hardware setup for offloading: CPU RAM offloading via PCIe 4.0 with 16 PCIe lanes per GPU and PCIe switches for pairs of GPUs. We assume zero latency for the upper bound estimation. In 8-bit, the model uses 1 GB of memory per billion parameters. PCIe 4.0 with 16 lanes has a throughput of 256 Gbit/s, so offloading 176B parameters takes 5.5 seconds. The throughput is twice as slow (128 Gbit/s) if we have two GPUs behind the same PCIe switch. + +2 **A real-world distributed setting** with 14 servers holding 2× RTX 3060, 4× 2080Ti, 2× 3090, 2× A4000, and 4× A5000 GPUs. These are personal servers and servers from university labs, spread across Europe and North America and connected to the Internet at speeds of 100–1000 Mbit/s. 4 servers operate from under firewalls. + +3 **An optimistic setup** that requires least communication. The client nodes have 8 CPU cores and no GPU. + +We provide more evaluations and discuss these results in more detail in **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf). + +## 🛠️ Contributing + +Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing. ## 📜 Citation From a7d3d021948409c71f95f7c8df33237fa6283a8f Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 13 Mar 2023 06:21:09 +0400 Subject: [PATCH 066/168] Fix invalid author email in setup.cfg (#287) --- .github/workflows/run-tests.yaml | 4 ++-- README.md | 2 +- setup.cfg | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index c1a01e8..5e71dba 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -18,7 +18,7 @@ jobs: id: cache-model uses: actions/cache@v3 with: - path: ~/.dummy + path: ~/converted_ok key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - name: Set up Python if: steps.cache-model.outputs.cache-hit != 'true' @@ -52,7 +52,7 @@ jobs: export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \ --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \ - --resize_token_embeddings 50000 + --resize_token_embeddings 50000 && touch ~/converted_ok run-tests: runs-on: ubuntu-latest diff --git a/README.md b/README.md index 182be65..15e27ce 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ If you don't use Anaconda, you can install PyTorch in [any other way](https://py See the instructions for macOS and Windows, the full requirements, and troubleshooting advice in our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-client). -## ⏱️ Benchmarks +## Benchmarks diff --git a/setup.cfg b/setup.cfg index a05ae6b..03335a2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ name = petals version = attr: petals.__version__ author = Petals Developers -author_email = petals-dev@googlegroups.com +author_email = petals-devs@googlegroups.com description = Easy way to efficiently run 100B+ language models without high-end GPUs long_description = file: README.md long_description_content_type = text/markdown From e0cef7375785c7433e2803509162fc7dfb2791d6 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 15 Mar 2023 17:21:30 +0400 Subject: [PATCH 067/168] Hotfix: Increase daemon_startup_timeout (#292) For some reasons, right now 15 sec is not enough to connect to the bootstrap peers in the public swarm, as reported by multiple users and observed by me. Increasing it to 120 sec until we find the root cause of the issue. --- src/petals/cli/run_server.py | 5 +++++ src/petals/client/remote_model.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 5fb700d..57761fd 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -47,6 +47,9 @@ def main(): parser.add_argument('--announce_maddrs', nargs='+', required=False, help='Visible multiaddrs the host announces for external connections from other peers') + parser.add_argument('--daemon_startup_timeout', type=float, default=120, + help='Timeout for the libp2p daemon connecting to initial peers') + parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication') parser.add_argument('--num_handlers', type=int, default=8, required=False, @@ -167,6 +170,8 @@ def main(): assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)" announce_maddrs = [f"/ip4/{public_ip}/tcp/{port}"] + args["startup_timeout"] = args.pop("daemon_startup_timeout") + if args.pop("increase_file_limit"): increase_file_limit() diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 2ec127f..3b16abe 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -32,7 +32,7 @@ class DistributedBloomConfig(BloomConfig): initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) - daemon_startup_timeout: int = 30 + daemon_startup_timeout: int = 120 # timeout for the libp2p daemon connecting to initial peers dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models request_timeout: int = 3 * 60 # a number of seconds for waiting result from each node max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) From 987f4d2b2fa0fbd13f73d3e3913d51315a340faf Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 29 Mar 2023 00:20:29 +0300 Subject: [PATCH 068/168] Update bitsandbytes, hivemind, transformers (#290) - new bitsandbytes supports newer *and* older GPUs - new hivemind supports a better bfloat16 codec Co-authored-by: Alexander Borzunov --- setup.cfg | 6 +++--- src/petals/bloom/block.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index 03335a2..ba3bedc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,12 +32,12 @@ packages = find: python_requires = >=3.7 install_requires = torch>=1.12 - bitsandbytes==0.34.0 + bitsandbytes==0.37.1 accelerate==0.15.0 huggingface-hub==0.11.1 - transformers==4.25.1 + transformers>=4.25.1,<5.0.0 speedtest-cli==2.1.3 - hivemind==1.1.5 + hivemind==1.1.6 tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/bloom/block.py b/src/petals/bloom/block.py index f4d50be..78171cf 100644 --- a/src/petals/bloom/block.py +++ b/src/petals/bloom/block.py @@ -8,10 +8,13 @@ from typing import Optional, Tuple import torch.nn.quantized.dynamic.modules.linear import transformers +from packaging import version from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): - assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1" + assert ( + version.parse("4.26.0") < version.parse(transformers.__version__) < version.parse("5.0.0") + ), "Please install a proper transformers version: pip install transformers>=4.26.0,<5.0.0" class WrappedBloomBlock(BloomBlock): From 2116df08bcbdaff48d185af3554f15582d615d45 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 29 Mar 2023 04:21:37 +0400 Subject: [PATCH 069/168] Fix deps, enable 8-bit by default for TP (#298) This PR fixes issues of #290: - hivemind bfloat16 codec crashed on dummy tensors (with 0 elements), see https://github.com/learning-at-home/hivemind/pull/560 (this PR makes Petals depend on the latest hivemind version from the repo, it's temporary) - transformers version check mismatched with the version allowed in `setup.cfg` Also: - This PR enables 8-bit by default for TP. Even though TP in 8-bit may be slower, we currently prefer to host more blocks to increase the network's stability. --- setup.cfg | 2 +- src/petals/bloom/block.py | 4 ++-- src/petals/server/server.py | 6 ------ 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index ba3bedc..c485cd5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = huggingface-hub==0.11.1 transformers>=4.25.1,<5.0.0 speedtest-cli==2.1.3 - hivemind==1.1.6 + hivemind @ git+https://github.com/learning-at-home/hivemind.git tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/bloom/block.py b/src/petals/bloom/block.py index 78171cf..9037ee4 100644 --- a/src/petals/bloom/block.py +++ b/src/petals/bloom/block.py @@ -13,8 +13,8 @@ from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _ if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): assert ( - version.parse("4.26.0") < version.parse(transformers.__version__) < version.parse("5.0.0") - ), "Please install a proper transformers version: pip install transformers>=4.26.0,<5.0.0" + version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0") + ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0" class WrappedBloomBlock(BloomBlock): diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 29e9d6b..4563e28 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -163,12 +163,6 @@ class Server: if load_in_8bit is None: load_in_8bit = device.type == "cuda" - if load_in_8bit and len(self.tensor_parallel_devices) > 1: - load_in_8bit = False - logger.warning( - "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. " - "You can explicitly set `--load_in_8bit True` to override this" - ) self.load_in_8bit = load_in_8bit logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format") From 74d8cda8c4a55c73c61624ffe7bfdb5c79911891 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 29 Mar 2023 04:41:07 +0400 Subject: [PATCH 070/168] Add Python 3.10 to CI (#299) --- .github/workflows/check-style.yaml | 2 +- .github/workflows/run-tests.yaml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/check-style.yaml b/.github/workflows/check-style.yaml index 42e1460..60ea42b 100644 --- a/.github/workflows/check-style.yaml +++ b/.github/workflows/check-style.yaml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v3 with: python-version: 3.8 - uses: isort/isort-action@master diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 5e71dba..3d48d37 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -22,7 +22,7 @@ jobs: key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - name: Set up Python if: steps.cache-model.outputs.cache-hit != 'true' - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: 3.9 - name: Cache dependencies @@ -59,14 +59,14 @@ jobs: needs: convert-model strategy: matrix: - python-version: [ 3.7, 3.8, 3.9 ] + python-version: [ '3.7', '3.8', '3.9', '3.10' ] fail-fast: false timeout-minutes: 15 steps: - name: Checkout uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: Cache dependencies From 892fa2386ae8982b06ffcbe6640db95e5fd67b68 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 29 Mar 2023 05:21:16 +0400 Subject: [PATCH 071/168] Remove CustomLinear8bitLt (#297) This became a part of https://github.com/TimDettmers/bitsandbytes/releases/tag/0.37.0. --- src/petals/utils/convert_block.py | 6 +- src/petals/utils/linear8bitlt_patch.py | 334 ------------------------- tests/test_linear8bitlt.py | 108 -------- 3 files changed, 3 insertions(+), 445 deletions(-) delete mode 100644 src/petals/utils/linear8bitlt_patch.py delete mode 100644 tests/test_linear8bitlt.py diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 4938289..b58cd1a 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -1,6 +1,7 @@ """ Tools for converting transformer blocks, applying quantization and/or tensor parallelism """ +import os import re from typing import Sequence @@ -75,17 +76,16 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0): """ # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes + os.environ["BITSANDBYTES_NOWELCOME"] = "1" import bitsandbytes as bnb - from petals.utils.linear8bitlt_patch import CustomLinear8bitLt - for n, module in model.named_children(): if len(list(module.children())) > 0: replace_8bit_linear(module, threshold) if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]: assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}" - model._modules[n] = CustomLinear8bitLt( + model._modules[n] = bnb.nn.Linear8bitLt( module.in_features, module.out_features, module.bias is not None, diff --git a/src/petals/utils/linear8bitlt_patch.py b/src/petals/utils/linear8bitlt_patch.py deleted file mode 100644 index 523436f..0000000 --- a/src/petals/utils/linear8bitlt_patch.py +++ /dev/null @@ -1,334 +0,0 @@ -""" -A patch to bitsandbytes 0.34.0 that introduces an option to run backward pass in default (fast) matrix layout. -Authors: modification by @borzunov, original code by @timdettmers. Please disregard commit authors in this file. - -Core idea: layouts apply the same permutation to every tile in the matrix. We can treat this as (batched) gather ops. - Reshape input tensor so that ij-th gather operation op will apply to ij-th elements in each tile. -Prototype: https://colab.research.google.com/drive/1EJ0MKifajXSSVq7O2_QGwtb0l6gRAGrh?usp=sharing -Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#L2130-L2136 -Exact match tests: see $REPO/tests/test_linear8bitlt.py -""" -import dataclasses -import logging -from typing import Optional, Tuple - -import bitsandbytes.functional as F -import torch -from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatMul8bitLt, MatmulLtState, prod -from bitsandbytes.nn import Linear8bitLt - - -def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]): - """ - Compute a permutation of indices that invert the specified (tiled) matrix transformation - - :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2] - :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere - :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size - :example: transform_tile function for the turing layout (bitsandbytes.functional as F) - :returns: indices - """ - d1, d2 = tile_size - assert 0 < d1 * d2 < 2**64 - tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2) - # encode each position in tile as a tuple of <= 8 unique bytes - permuted_tile_indices = torch.zeros_like(tile_indices) - for i in range(8): - # select i-th byte, apply transformation and trace where each index ended up - ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256 - sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous() - assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow" - permuted_tile_i = transform_tile(sample_tile_i) - ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128 - permuted_tile_indices += ith_permuted_indices * (256**i) - if d1 * d2 < 256**i: - break # if all indices fit in i bytes, stop early - return permuted_tile_indices - - -def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: - """ - Undo a tiled permutation such as turing or ampere layout - - :param permuted_tensor: torch tensor in a permuted layout - :param tile_indices: reverse transformation indices, from get_inverse_transform_indices - :return: contiguous row-major tensor - """ - (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape - assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles" - tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() - outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda - outputs[tile_indices.flatten()] = tensor - outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows) - outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols) - return outputs.reshape(rows, cols).contiguous() - - -# the rest of this file is just a patch to bitsandbytes that modifies Linear8bitLt and dependencies - - -class CustomLinear8bitLt(Linear8bitLt): - def __init__(self, *args, memory_efficient_backward: bool = False, **kwargs): - assert not memory_efficient_backward, "memory_efficient_backward is no longer used" - super().__init__(*args, **kwargs) - old_state, self.state = self.state, CustomMatmulLtState() - self.state.threshold = old_state.threshold - self.state.has_fp16_weights = old_state.has_fp16_weights - self.state.memory_efficient_backward = old_state.memory_efficient_backward - if old_state.threshold > 0.0 and not old_state.has_fp16_weights: - self.state.use_pool = True - - def forward(self, x: torch.Tensor): - self.state.is_training = self.training - if self.weight.CB is not None: - self.init_8bit_state() - - # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - out = custom_matmul8bitlt(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights: - if self.state.CB is not None and self.state.CxB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB - return out - - -@dataclasses.dataclass(init=True) -class CustomMatmulLtState(MatmulLtState): - tile_indices: Optional[torch.Tensor] = None - force_no_igemmlt: bool = False - - def get_tile_size(self): - assert self.formatB in ( - "col_turing", - "col_ampere", - ), f"please find this assert and manually enter tile size for {self.formatB}" - return (8, 32) if self.formatB == "col_turing" else (32, 32) - - -def custom_matmul8bitlt( - A: torch.Tensor, - B: torch.Tensor, - out: torch.Tensor = None, - state: CustomMatmulLtState = None, - threshold=0.0, - bias=None, -): - state = state or MatmulLtState() - if threshold > 0.0: - state.threshold = threshold - return CustomMatMul8bitLt.apply(A, B, out, bias, state) - - -class CustomMatMul8bitLt(MatMul8bitLt): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - - @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=CustomMatmulLtState): - using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt - # default to pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - ctx.bias = bias - if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) - - # 1. Quantize A - # 2. Quantize B - # 3. Matmul - # 4. Mixed-precision decomposition matmul - # 5. Save state - formatB = state.formatB - input_shape = A.shape - if state.outlier_pool is None: - state.outlier_pool = GlobalOutlierPooler.get_instance() - - # Cast A to fp16 - if A.dtype != torch.float16: - logging.debug(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") - - # 1. Quantize A - if len(A.shape) == 3: - A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) - - if state.threshold > 0.0 and coo_tensorA is not None: - if state.has_fp16_weights: - idx = torch.unique(coo_tensorA.colidx).long() - CA[:, idx] = 0 - CAt[:, idx] = 0 - subA = A[:, idx] - state.subB = B[:, idx].t().contiguous() - state.idx = idx - else: - if state.CxB is None and using_igemmlt: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - else: - if not state.has_fp16_weights and state.CxB is None and using_igemmlt: - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - subA = None - - # 2. Quantize B - if state.has_fp16_weights: - has_grad = True if (getattr(B, "grad", None) is not None) else False - is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) - if is_transposed: - B = B.contiguous() - - if (state.is_training and not has_grad) or state.CxB is None: - state.reset_grads() - ( - CB, - state.CBt, - state.SCB, - state.SCBt, - coo_tensorB, - ) = F.double_quant(B.to(torch.float16)) - if using_igemmlt: - state.CxB, state.SB = F.transform(CB, to_order=formatB) - else: - state.CB = CB - else: - has_grad = False - - if coo_tensorA is not None and not state.has_fp16_weights: - # extract outliers - - outlier_idx = torch.unique(coo_tensorA.colidx) - state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx - if state.CxB is not None: - outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - else: - outliers = state.CB[:, state.idx.long()].clone() - - state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) - CA[:, state.idx.long()] = 0 - CAt[:, state.idx.long()] = 0 - subA = A[:, state.idx.long()] - - shapeB = state.SB[0] if state.SB else B.shape - - if len(input_shape) == 3: - output_shape = (input_shape[0], input_shape[1], shapeB[0]) - else: - output_shape = (input_shape[0], shapeB[0]) - - # 3. Matmul - if using_igemmlt: - C32A, SA = F.transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - if bias is None or bias.dtype == torch.float16: - # we apply the fused bias here - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A.dtype) - else: # apply bias separately - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A.dtype).add_(bias) - - else: - A_wo_outliers = A.clone() - if state.idx is not None: - A_wo_outliers[:, state.idx.long()] = 0 - output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype)) - output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0)) - if bias is not None: - output = output.add_(bias) - - # 4. Mixed-precision decomposition matmul - if coo_tensorA is not None and subA is not None: - output += torch.matmul(subA, state.subB) - - # 5. Save state - ctx.state = state - - ctx.formatB = formatB - ctx.grad_shape = input_shape - ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype - - if any(ctx.needs_input_grad[:2]): - ctx.tensors = (CAt, subA) - ctx.tensor_states = (SCAt, state.idx) - else: - ctx.tensors = [None, None] - ctx.tensor_states = (None, None) - ctx.save_for_backward(None, None) - - clone_func = torch.clone if len(output_shape) == 3 else lambda x: x - return clone_func(output.view(output_shape)) - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad - CAt, subA = ctx.tensors - SCAt, idx = ctx.tensor_states - formatB = ctx.formatB - state = ctx.state - grad_A = grad_B = grad_bias = None - - if req_gradBias: - # compute grad_bias first before changing grad_output dtype - grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) - - # Cast grad_output to fp16 - if len(grad_output.shape) == 3: - grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) - if req_gradB: - CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) - gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) - if state.threshold > 0.0 and subA is not None: - grad_B[:, idx] += torch.matmul(grad_output.t(), subA) - - if req_gradA: - if state.CBt is not None: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) - - elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) - elif state.CxB is not None: - - if state.tile_indices is None: - order, tile_size = state.formatB, state.get_tile_size() - transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) - with torch.no_grad(): - state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device) - - CB = ( - undo_layout(state.CxB, state.tile_indices) - .to(ctx.dtype_A) - .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - ) - grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) - else: - raise Exception("State must contain either CBt or CB or CxB matrix for backward") - - return grad_A, grad_B, None, grad_bias, None diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py deleted file mode 100644 index f954c67..0000000 --- a/tests/test_linear8bitlt.py +++ /dev/null @@ -1,108 +0,0 @@ -import bitsandbytes as bnb -import pytest -import torch -from bitsandbytes import functional as F - -from petals.utils.linear8bitlt_patch import CustomLinear8bitLt, get_inverse_transform_indices, undo_layout - - -@pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), - reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", -) -def test_layout_exact_match(): - x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda() - for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"): - transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) - tile_indices = get_inverse_transform_indices(transform, tile_size) - cxb = transform(x) - - torch.cuda.synchronize() - restored_x = undo_layout(cxb, tile_indices) - torch.cuda.synchronize() - assert restored_x.is_contiguous() - assert torch.all(torch.eq(restored_x, x)) - - -@pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), - reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", -) -def test_linear_exact_match(): - linear = torch.nn.Linear(1024, 3072) - x = torch.randn(3, 1024, dtype=torch.half) - linear8bitlt = bnb.nn.Linear8bitLt( - linear.in_features, - linear.out_features, - linear.bias is not None, - has_fp16_weights=False, - threshold=6.0, - memory_efficient_backward=True, - ) - linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to( - linear.weight.dtype - ) - linear8bitlt.bias = linear.bias - linear8bitlt.cuda() - - linear_custom = CustomLinear8bitLt( - linear.in_features, - linear.out_features, - linear.bias is not None, - has_fp16_weights=False, - threshold=6.0, - ) - linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False - ).to(linear.weight.dtype) - linear_custom.bias = linear.bias - linear_custom.cuda() - - x_ref = x.clone().cuda().requires_grad_(True) - x_ours = x.clone().cuda().requires_grad_(True) - fx_ref = linear8bitlt(x_ref).float() - grad_proj = torch.randn_like(fx_ref) - (fx_ref * grad_proj).mean().backward() - - fx_ours = linear_custom(x_ours).float() - (fx_ours * grad_proj).mean().backward() - assert torch.equal(fx_ref, fx_ours) - assert torch.allclose(x_ref.grad, x_ours.grad) - assert not linear_custom.state.has_fp16_weights - assert linear_custom.state.CB is None - assert linear_custom.state.CxB is not None - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") -def test_linear_no_igemmlt(): - linear = torch.nn.Linear(1024, 3072) - x = torch.randn(3, 1024, dtype=torch.half) - linear_custom = CustomLinear8bitLt( - linear.in_features, - linear.out_features, - linear.bias is not None, - has_fp16_weights=False, - threshold=6.0, - ) - linear_custom.state.force_no_igemmlt = True - - linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False - ).to(linear.weight.dtype) - linear_custom.bias = linear.bias - linear_custom.cuda() - linear.half().cuda() - - x_ref = x.clone().cuda().requires_grad_(True) - x_ours = x.clone().cuda().requires_grad_(True) - fx_ref = linear(x_ref).float() - grad_proj = torch.randn_like(fx_ref) - (fx_ref * grad_proj).mean().backward() - - fx_ours = linear_custom(x_ours).float() - (fx_ours * grad_proj).mean().backward() - assert torch.allclose(fx_ref, fx_ours, atol=0.02) - assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) - assert not linear_custom.state.has_fp16_weights - assert linear_custom.state.CB is not None - assert linear_custom.state.CxB is None From 6c6150f6844194e2e7412d13b969642a109d74c4 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 31 Mar 2023 16:39:48 +0400 Subject: [PATCH 072/168] Remove use_auto_relay=True in client (#300) `use_auto_relay=True` makes the libp2p daemon look for relays to become reachable if we are behind NAT/firewall. However, being reachable is not necessary for the Petals client, and we should not spend the relays' capacity on this. --- src/petals/client/remote_model.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 3b16abe..937cd9c 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -102,19 +102,15 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel): assert len(self.h) == 0 config.n_layer = n_layer - dht = ( - config.dht - if config.dht is not None - else hivemind.DHT( + dht = config.dht + if dht is None: + dht = hivemind.DHT( initial_peers=config.initial_peers, client_mode=True, num_workers=n_layer, startup_timeout=config.daemon_startup_timeout, start=True, - use_relay=True, - use_auto_relay=True, ) - ) assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance" self.h = RemoteSequential( config, From 21c3526ec1b82e274fde30b12d6367089b77b992 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 12 Apr 2023 21:38:43 +0400 Subject: [PATCH 073/168] Start SequenceManager's thread only after first .make_sequence() (#301) **Why?** - We'd like to avoid excess threads for the original sequence manager in case if we only use its slices (e.g. when we add adapters or need only a subset of model blocks): - If we create a sequence manager just before a fork (e.g. in a web app backend or a multi-thread benchmark), we'd like to avoid excess threads in the original process and only use this thread in child processes where we actually call `.make_sequence()`. --- src/petals/client/remote_sequential.py | 1 - src/petals/client/routing/sequence_manager.py | 55 ++++++++----------- src/petals/dht_utils.py | 29 ++++++---- src/petals/server/server.py | 4 +- tests/test_remote_sequential.py | 2 +- tests/test_sequence_manager.py | 2 +- 6 files changed, 45 insertions(+), 48 deletions(-) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 31a33af..788805d 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -48,7 +48,6 @@ class RemoteSequential(nn.Module): request_timeout=config.request_timeout, max_retries=config.max_retries, allowed_servers=config.allowed_servers, - start=True, **kwargs, ) self.is_subsequence = False diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 6899fd1..25c68ef 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -10,7 +10,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Union from weakref import WeakMethod import numpy as np -from hivemind import DHT, P2P, MSGPackSerializer, PeerID +from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time from hivemind.dht.node import Blacklist from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import P2PHandlerError @@ -66,8 +66,7 @@ class RemoteSequenceManager: rpc_info: Optional[dict] = None, allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None, banned_peers: Optional[Blacklist] = None, - *, # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below) - start: bool, + # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below) ): assert len(block_uids) > 0, "Sequences must contain at least one block" self.dht, self.p2p = dht, p2p @@ -75,6 +74,7 @@ class RemoteSequenceManager: self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff self.lock_changes = threading.Lock() self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update)) + self._thread_start_lock = threading.Lock() self.policy = NoSpendingPolicy() self._rpc_info = rpc_info @@ -87,23 +87,16 @@ class RemoteSequenceManager: if sequence_info is None: self.sequence_info = RemoteSequenceInfo.make_empty(block_uids) - self.update(wait=False) + + # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records + # in the first _update() instead of the latest ones. This makes the first .update() faster. + petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True) + self._need_latest_infos = False else: self.sequence_info = sequence_info assert block_uids == sequence_info.block_uids self._thread.ready.set() # no need to await the first dht fetch - - if start: - self.run_in_background() - - def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None: - """ - Starts the updater thread in a background. if await_ready, this method will wait until sequence manager - is ready to process incoming requests or for :timeout: seconds max. - """ - self._thread.start() - if await_ready: - self._thread.ready.wait(timeout) + self._need_latest_infos = True def make_sequence( self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random" @@ -115,10 +108,10 @@ class RemoteSequenceManager: :param end_index: optional index of the last module (non-inclusive), default = after last of block uids :param mode: either random or fastest """ - if not self.is_alive(): - logger.error("Using a sequence manager that is not running: it has either crashed or never started") + with self._thread_start_lock: + if not self.is_alive(): + self._thread.start() if not self.ready.is_set(): - logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready") self.update(wait=True) # this will await an existing update or trigger a new one (if not updating) end_index = end_index if end_index is not None else len(self) @@ -163,7 +156,6 @@ class RemoteSequenceManager: rpc_info=self._rpc_info, allowed_servers=self.allowed_servers, banned_peers=self.banned_peers, - start=True, ) def update(self, *, wait: bool): @@ -178,8 +170,10 @@ class RemoteSequenceManager: for attempt_no in itertools.count(): try: new_block_infos = petals.dht_utils.get_remote_module_infos( - self.dht, self.block_uids, expiration_time=float("inf") + self.dht, self.block_uids, latest=self._need_latest_infos ) + self._need_latest_infos = True # All future _update() should use latest infos + for block_info in new_block_infos: if not block_info: continue @@ -259,6 +253,10 @@ class RemoteSequenceManager: def rpc_info(self): """Return the rpc_info queried from one of the servers that hold the first block""" if self._rpc_info is None: + with self._thread_start_lock: + if not self.is_alive(): + self._thread.start() + for attempt_no in itertools.count(): peer_id = None try: @@ -320,18 +318,11 @@ class _SequenceManagerUpdateThread(threading.Thread): self.ref_update_manager = ref_update_manager self.ready = threading.Event() self.trigger = threading.Event() - self.last_update_time = -float("inf") self.update_period = update_period self.should_shutdown = False def run(self) -> None: while not self.should_shutdown: - self.trigger.wait(max(0.0, min(self.update_period, time.perf_counter() - self.last_update_time))) - - if self.should_shutdown: - logger.debug(f"{self.__class__.__name__} is shutting down") - break - update_manager = self.ref_update_manager() if update_manager is None: logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists") @@ -345,16 +336,18 @@ class _SequenceManagerUpdateThread(threading.Thread): finally: del update_manager + self.trigger.wait(self.update_period) + logger.debug(f"{self.__class__.__name__} thread exited") def shutdown(self, timeout: Optional[float] = None): self.should_shutdown = True self.trigger.set() - self.join(timeout) + if self.is_alive(): + self.join(timeout) def __del__(self): - if self.is_alive(): - self.shutdown() + self.shutdown() def maybe_log_traceback(exc: Exception): diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index 3542e40..06c30eb 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -93,7 +93,7 @@ async def _get_remote_sequence( ) -> petals.client.RemoteSequential: uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)] p2p = await dht.replicate_p2p() - manager = petals.client.RemoteSequenceManager(dht, uids, p2p, start=True) + manager = petals.client.RemoteSequenceManager(dht, uids, p2p) return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager) @@ -124,7 +124,7 @@ async def _get_remote_module( single_uid = isinstance(uid_or_uids, ModuleUID) uids = [uid_or_uids] if single_uid else uid_or_uids p2p = await dht.replicate_p2p() - managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p, start=True) for uid in uids) + managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids) modules = [ petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers @@ -133,21 +133,26 @@ async def _get_remote_module( def get_remote_module_infos( - dht: DHT, uid_or_uids: Union[ModuleUID, Sequence[ModuleUID]], expiration_time: Optional[DHTExpiration] = None -) -> List[Optional[RemoteModuleInfo]]: - single_uid = isinstance(uid_or_uids, ModuleUID) - uids = [uid_or_uids] if single_uid else uid_or_uids - infos = dht.run_coroutine( - partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), - return_future=False, + dht: DHT, + uids: Sequence[ModuleUID], + expiration_time: Optional[DHTExpiration] = None, + *, + latest: bool = False, + return_future: bool = False, +) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]: + return dht.run_coroutine( + partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest), + return_future=return_future, ) - return infos[0] if single_uid else infos async def _get_remote_module_infos( - dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration] + dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool ) -> List[Optional[RemoteModuleInfo]]: - if expiration_time is None: + if latest: + assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both" + expiration_time = math.inf + elif expiration_time is None: expiration_time = get_dht_time() num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 4563e28..4f2a645 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -324,14 +324,14 @@ class Server: # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time, # this delay decreases the probability of a race condition while choosing the best blocks to serve. time.sleep(random.random() * 2 * self.mean_block_selection_delay) - module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf) + module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True) return block_selection.choose_best_blocks(self.num_blocks, module_infos) def _should_choose_other_blocks(self) -> bool: if self.strict_block_indices is not None: return False - module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf) + module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True) return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality) def shutdown(self): diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 18b41a1..a69c42f 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -48,7 +48,7 @@ def test_remote_sequential(): # test RemoteSequential with lossy compression block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] lossy_sequential = RemoteSequential( - config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p, start=True) + config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p) ) test_inputs.grad = None diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 9185ef1..f0b61cf 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -26,7 +26,7 @@ def test_sequence_manager_basics(mode: str): sequential = RemoteSequential( config, dht, - sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True), + sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt), ) sequence = sequential.sequence_manager.make_sequence(mode=mode) From 35662b4a16608a7fd6e78b776ae8cd72f916ca70 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 12 Apr 2023 23:07:29 +0400 Subject: [PATCH 074/168] Require bitsandbytes == 0.38.0.post2, hivemind == 1.1.7 (#302) In particular, this PR fixes 8-bit support on nvidia16 GPUs (such as 1660) by including https://github.com/TimDettmers/bitsandbytes/pull/292. This support was requested multiple times on Discord. --- setup.cfg | 4 ++-- src/petals/cli/run_server.py | 2 +- src/petals/client/remote_model.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index c485cd5..ca1f4c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,12 +32,12 @@ packages = find: python_requires = >=3.7 install_requires = torch>=1.12 - bitsandbytes==0.37.1 + bitsandbytes==0.38.0.post2 accelerate==0.15.0 huggingface-hub==0.11.1 transformers>=4.25.1,<5.0.0 speedtest-cli==2.1.3 - hivemind @ git+https://github.com/learning-at-home/hivemind.git + hivemind==1.1.7 tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 57761fd..5e7efb5 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -47,7 +47,7 @@ def main(): parser.add_argument('--announce_maddrs', nargs='+', required=False, help='Visible multiaddrs the host announces for external connections from other peers') - parser.add_argument('--daemon_startup_timeout', type=float, default=120, + parser.add_argument('--daemon_startup_timeout', type=float, default=60, help='Timeout for the libp2p daemon connecting to initial peers') parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication') diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 937cd9c..dc987e4 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -32,7 +32,7 @@ class DistributedBloomConfig(BloomConfig): initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) - daemon_startup_timeout: int = 120 # timeout for the libp2p daemon connecting to initial peers + daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models request_timeout: int = 3 * 60 # a number of seconds for waiting result from each node max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) From 5c0b4286b2f504e4f91d7f1436bb7e71a841c0b7 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 13 Apr 2023 00:00:35 +0400 Subject: [PATCH 075/168] Suggest commands for Docker first (#304) --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 15e27ce..73d952b 100644 --- a/README.md +++ b/README.md @@ -35,19 +35,19 @@ for input_ids, labels in data_loader: ### Connect your GPU and increase Petals capacity -Run this in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+): +Run our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windows with [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)): ```bash -conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -pip install -U petals -python -m petals.cli.run_server bigscience/bloom-petals +sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ + learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals --port 31330 ``` -Or use our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windows with [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)): +Or run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+): ```bash -sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ - learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals --port 31330 +conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia +pip install -U petals +python -m petals.cli.run_server bigscience/bloom-petals ``` 📚 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to configure the server to use multiple GPUs, address common issues, etc. From 98be9ffe4cd8be2461d356672a5c2721e6d8c04b Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 13 Apr 2023 01:05:35 +0400 Subject: [PATCH 076/168] Relax the rest of Hugging Face dependencies (#305) --- setup.cfg | 4 ++-- src/petals/cli/convert_model.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index ca1f4c7..09182c6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,8 +33,8 @@ python_requires = >=3.7 install_requires = torch>=1.12 bitsandbytes==0.38.0.post2 - accelerate==0.15.0 - huggingface-hub==0.11.1 + accelerate>=0.15.0,<1.0.0 + huggingface-hub>=0.11.1,<1.0.0 transformers>=4.25.1,<5.0.0 speedtest-cli==2.1.3 hivemind==1.1.7 diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index 6f7499d..95b08e4 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -6,7 +6,7 @@ import torch.backends.quantized import torch.nn as nn import transformers from hivemind.utils.logging import get_logger -from huggingface_hub import Repository +from huggingface_hub import HfApi, Repository from tqdm.auto import tqdm from transformers.models.bloom.modeling_bloom import BloomModel @@ -66,6 +66,8 @@ def main(): ) os.makedirs(args.output_path, exist_ok=True) + api = HfApi(token=args.use_auth_token) + api.create_repo(args.output_repo, repo_type="model", exist_ok=True) repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token) repo.git_pull() From c0e0e1319dfae0c307cdfc8cb86825b827bc598e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 13 Apr 2023 14:41:54 +0400 Subject: [PATCH 077/168] Force transformers to use config.torch_dtype by default (#307) --- src/petals/client/remote_model.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index dc987e4..d67d4bf 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -71,20 +71,33 @@ def force_non_empty_weights(): nn.Module.register_parameter = possibly_patched_register_parameter -class _LowCPUMemoryMixin: +class _FromPretrainedDefaultsMixin: @classmethod - def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + *args, + low_cpu_mem_usage: Optional[bool] = None, + torch_dtype: Optional[Union[str, torch.dtype]] = None, + **kwargs, + ): if low_cpu_mem_usage is None: low_cpu_mem_usage = True - return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) + if torch_dtype is None: + # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast, + # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights. + torch_dtype = "auto" + return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs) from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( "low_cpu_mem_usage(`bool`, *optional*)", "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", + ).replace( + "torch_dtype (`str` or `torch.dtype`, *optional*)", + 'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)', ) -class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel): +class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel): """BloomModel, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [ @@ -218,7 +231,7 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel): ) -class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM): +class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM): """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = ( @@ -256,7 +269,7 @@ class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, Blo self.lm_head.bias[...] = new_lm_head.bias -class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification): +class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification): _keys_to_ignore_on_load_missing = ( BloomForSequenceClassification._keys_to_ignore_on_load_missing + DistributedBloomModel._keys_to_ignore_on_load_missing From 93c4eba5d175d4886a7ff6272b85624e871b1d78 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 21 Apr 2023 05:41:01 +0400 Subject: [PATCH 078/168] Bump version to 1.1.4 (#306) --- src/petals/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 1a34085..373c40a 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,6 +1,6 @@ from petals.client import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.1.3" +__version__ = "1.1.4" _initialize_logs() From 454c193863eed5d06ccf2c33f5187c6313ffd1bb Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 25 Apr 2023 17:20:19 +0400 Subject: [PATCH 079/168] Fix OOMs happening in case of accelerate >= 0.16.0 (#310) - After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()` - In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of dtype of the weights (https://github.com/huggingface/accelerate/pull/920) - Because of that, blocks and attention caches used float32, which caused OOMs - This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`) --- setup.cfg | 2 +- src/petals/bloom/from_pretrained.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 09182c6..786c8f5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ python_requires = >=3.7 install_requires = torch>=1.12 bitsandbytes==0.38.0.post2 - accelerate>=0.15.0,<1.0.0 + accelerate>=0.16.0,<1.0.0 huggingface-hub>=0.11.1,<1.0.0 transformers>=4.25.1,<5.0.0 speedtest-cli==2.1.3 diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index 9f1d12b..4748b41 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -68,7 +68,7 @@ def load_pretrained_block( param = state_dict[param_name] if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): param = param.to(torch_dtype) - set_module_tensor_to_device(block, param_name, "cpu", value=param) + set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}") return block From 8f6342a8611f2b6809808d97862132a01f4036b9 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 7 May 2023 13:41:13 +0400 Subject: [PATCH 080/168] Refactor RemoteSequenceManager (#309) This PR: 1. **Extracts `SequenceManagerConfig` and `SequenceManagerState` subclasses.** The config is provided by caller and never changed from inside `RemoteSequenceManager`. The state is a part of the `RemoteSequenceManager`'s state shared between the main manager and its slices. We fix some slicing bugs along the way. 2. **Removes `dht_prefix` and `p2p` arguments, makes `dht` argument optional.** `dht_prefix` can always be overridden using `config.dht_prefix`. `p2p` actually needed only under the hood of `RemoteSequenceManager`, so it can extract it by itself without exposing this low-level class to callers. If strictly necessary, a caller can provide `p2p` as a part of `SequenceManagerState`. `dht` is also needed only by `RemoteSequenceManager`, so we can make it optional in the parent classes and create it automatically when it's not provided. 3. **Simplifies retry logic.** Previously, we could have "nested" retry loops: one in `._update()`, another in inference/forward/backward steps. The loop in `._update()` could introduce issues to concurrent inference/forward/backward calls, since it blocks the entire class if its delay period becomes too high. Now this logic is simplified: `._update()` performs only one attempt to fetch the DHT info, any retries are triggered by the inference/forward/backward steps. 4. **Removes deprecated `RemoteTransformerBlock`.** `RemoteTransformerBlock` was deprecated a long time ago, before Petals 1.0.0. Its removal is long due. 5. **Removes `dht_utils.get_remote_module()`, `dht_utils.get_remote_sequence()`.** This functions duplicate the functionality of the `RemoteSequential` constructor. 6. (minor) **Removes `RemoteSequential.is_subsequence` flag.** This flag worked incorrectly and was never used. I am removing it for the sake of simplicity. --- README.md | 2 + src/petals/client/__init__.py | 2 +- src/petals/client/inference_session.py | 13 +- src/petals/client/remote_model.py | 31 +- src/petals/client/remote_sequential.py | 81 ++--- src/petals/client/routing/sequence_info.py | 4 +- src/petals/client/routing/sequence_manager.py | 299 +++++++++--------- src/petals/client/sequential_autograd.py | 18 +- src/petals/dht_utils.py | 61 ---- tests/test_block_exact_match.py | 13 +- tests/test_chained_calls.py | 13 +- tests/test_remote_sequential.py | 7 +- tests/test_sequence_manager.py | 5 +- tests/test_server_stats.py | 15 +- 14 files changed, 210 insertions(+), 354 deletions(-) diff --git a/README.md b/README.md index 73d952b..f157cc0 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,8 @@ See the instructions for macOS and Windows, the full requirements, and troublesh ## Benchmarks +The benchmarks below are for BLOOM-176B: +
diff --git a/src/petals/client/__init__.py b/src/petals/client/__init__.py index b728962..5ff26bc 100644 --- a/src/petals/client/__init__.py +++ b/src/petals/client/__init__.py @@ -5,6 +5,6 @@ from petals.client.remote_model import ( DistributedBloomForSequenceClassification, DistributedBloomModel, ) -from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock +from petals.client.remote_sequential import RemoteSequential from petals.client.routing.sequence_manager import RemoteSequenceManager from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 24a188a..93700f9 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -8,7 +8,6 @@ from typing import AsyncIterator, List, Optional import torch from hivemind import ( - P2P, MSGPackSerializer, anext, deserialize_torch_tensor, @@ -162,9 +161,8 @@ class InferenceSession: An interface to a multi-step *inference* session for a sequence of remote transformer blocks """ - def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int): + def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int): self._sequence_manager = sequence_manager - self._p2p = p2p self._closed = False self._chosen_spans = [] self._server_sessions = [] @@ -181,7 +179,7 @@ class InferenceSession: server_sessions = [] try: for span in chosen_spans: - stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id) + stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id) span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end]) metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id) session = RemoteExpertWorker.run_coroutine( @@ -189,7 +187,7 @@ class InferenceSession: stub, span_uids, rpc_info=self._sequence_manager.rpc_info, - timeout=self._sequence_manager.request_timeout, + timeout=self._sequence_manager.config.request_timeout, max_length=self._max_length, **metadata, ) @@ -305,9 +303,8 @@ class InferenceSession: self._sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None: - self._sequence_manager.on_request_failure(span.peer_id) - if attempt_no + 1 == self._sequence_manager.max_retries: + self._sequence_manager.on_request_failure(span.peer_id if span is not None else None) + if attempt_no + 1 == self._sequence_manager.config.max_retries: raise delay = self._sequence_manager.get_retry_delay(attempt_no) logger.warning( diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index d67d4bf..0d218d1 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -18,13 +18,14 @@ from transformers.models.bloom import ( from petals.bloom.modeling_utils import LMHead from petals.client.remote_generation import RemoteGenerationMixin from petals.client.remote_sequential import RemoteSequential +from petals.client.routing.sequence_manager import SequenceManagerConfig from petals.constants import PUBLIC_INITIAL_PEERS from petals.utils.misc import DUMMY logger = get_logger(__name__) -class DistributedBloomConfig(BloomConfig): +class DistributedBloomConfig(BloomConfig, SequenceManagerConfig): """ A bloom config that contains information about DHT peers. To create a distributed model, one must provide dht_prefix and either initial_peers or dht. @@ -33,15 +34,9 @@ class DistributedBloomConfig(BloomConfig): initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers - dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models - request_timeout: int = 3 * 60 # a number of seconds for waiting result from each node - max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) - allowed_servers: Optional[ - Collection[Union[str, hivemind.PeerID]] - ] = None # if defined, send requests only to these servers pre_seq_len: int = 0 # a number of tokens for prompt tuning. - tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters'] + tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"] # This settings matter for running the client with dtype bfloat16 on CPU. # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations. @@ -106,30 +101,16 @@ class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel): config_class = DistributedBloomConfig - def __init__(self, config: DistributedBloomConfig): + def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None): assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." - assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)" + assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization super().__init__(config) assert len(self.h) == 0 config.n_layer = n_layer - dht = config.dht - if dht is None: - dht = hivemind.DHT( - initial_peers=config.initial_peers, - client_mode=True, - num_workers=n_layer, - startup_timeout=config.daemon_startup_timeout, - start=True, - ) - assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance" - self.h = RemoteSequential( - config, - dht, - config.dht_prefix, - ) + self.h = RemoteSequential(config, dht=dht) # Forbid accumulate grads for embeddings and layernorm self.set_requires_grad(False) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 788805d..8bc60ff 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Optional, Union import torch -from hivemind import DHT, P2P, get_logger +from hivemind import DHT, get_logger from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from torch import nn @@ -25,39 +25,26 @@ class RemoteSequential(nn.Module): def __init__( self, config: petals.client.DistributedBloomConfig, - dht: DHT, - dht_prefix: Optional[str] = None, - p2p: Optional[P2P] = None, + *, sequence_manager: Optional[RemoteSequenceManager] = None, - **kwargs, + dht: Optional[DHT] = None, + start_block: Optional[int] = None, + end_block: Optional[int] = None, ): super().__init__() self.config = config - self.dht = dht - self.dht_prefix = dht_prefix or config.dht_prefix - self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p - num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager) - block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)) + assert sequence_manager is None or ( + dht is None and start_block is None and end_block is None + ), "`dht`, `start_block`, and `end_block` have no effect when you provide a custom `sequence_manager`" if sequence_manager is None: - logger.debug(f"Creating new sequence manager for block uids: {block_uids}") - self.sequence_manager = RemoteSequenceManager( - dht, - block_uids, - self.p2p, - request_timeout=config.request_timeout, - max_retries=config.max_retries, - allowed_servers=config.allowed_servers, - **kwargs, - ) - self.is_subsequence = False - else: - logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules") - if kwargs: - logger.warning(f"Parameters {kwargs} are ignored because sequence_manager is explicitly provided") - self.sequence_manager = sequence_manager - assert isinstance(sequence_manager.sequence_info.block_uids, tuple) - self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids + if start_block is None: + start_block = 0 + if end_block is None: + end_block = self.config.n_layer + block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block)) + sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht) + self.sequence_manager = sequence_manager def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY): assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]" @@ -66,23 +53,10 @@ class RemoteSequential(nn.Module): return outputs def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential: - assert isinstance(ix, (int, slice)) - if isinstance(ix, int): - return RemoteTransformerBlock( - self.config, - self.dht, - dht_prefix=self.dht_prefix, - p2p=self.p2p, - sequence_manager=self.sequence_manager[ix], - ) - else: - return RemoteSequential( - self.config, - self.dht, - dht_prefix=self.dht_prefix, - p2p=self.p2p, - sequence_manager=self.sequence_manager[ix], - ) + return RemoteSequential( + self.config, + sequence_manager=self.sequence_manager[ix], + ) def __iter__(self): for block_index in range(len(self)): @@ -92,22 +66,7 @@ class RemoteSequential(nn.Module): return len(self.sequence_manager) def inference_session(self, **kwargs) -> InferenceSession: - return InferenceSession(self.sequence_manager, self.p2p, **kwargs) + return InferenceSession(self.sequence_manager, **kwargs) def extra_repr(self) -> str: return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}" - - -class RemoteTransformerBlock(RemoteSequential): - """Single transformer block hosted by swarm - - This class is deprecated and kept for backward compatibility. - It will be removed soon in favor of using ``RemoteSequential`` directly. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert len(self) == 1, "Remote Block is a sequence size 1" - - def extra_repr(self): - return f"{self.sequence_manager.block_uids[0]}" diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py index de7eb37..8dafb6e 100644 --- a/src/petals/client/routing/sequence_info.py +++ b/src/petals/client/routing/sequence_info.py @@ -27,14 +27,14 @@ class RemoteSequenceInfo: block_infos: Tuple[RemoteModuleInfo, ...] # note: the contents of RemoteModuleInfo can and will be updated spans_by_priority: List[RemoteSpanInfo] spans_containing_block: Tuple[List[RemoteSpanInfo], ...] - last_updated_time: float + last_updated_time: Optional[float] @classmethod def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T: block_uids = tuple(block_uids) empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids) empty_spans = tuple([] for _ in range(len(block_uids))) - return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf")) + return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=None) def __getitem__(self, ix: slice): assert isinstance(ix, slice) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 25c68ef..5f387c4 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import dataclasses import itertools import logging import random @@ -13,7 +14,6 @@ import numpy as np from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time from hivemind.dht.node import Blacklist from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2PHandlerError from hivemind.proto import runtime_pb2 from hivemind.utils.logging import get_logger @@ -26,6 +26,33 @@ from petals.server.handler import TransformerConnectionHandler logger = get_logger(__name__) +@dataclasses.dataclass +class SequenceManagerConfig: + allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers + + request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests + update_period: float = 60 # refresh DHT information once in this many seconds + + max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) + min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) + max_backoff: float = 60 # limit maximal sleep time between retries to this value + ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds + + +@dataclasses.dataclass +class SequenceManagerState: + p2p: P2P = None + sequence_info: Optional[RemoteSequenceInfo] = None + rpc_info: Optional[dict] = None + banned_peers: Optional[Blacklist] = None + + def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState: + return dataclasses.replace(self, sequence_info=self.sequence_info[ix]) + + def __len__(self) -> int: + return len(self.sequence_info) + + class RemoteSequenceManager: """ Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks. @@ -34,67 +61,56 @@ class RemoteSequenceManager: Using this information, sequence manager can form sequences of servers that collectively have the full sequence. To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr). - :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks - :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers - :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p()) - :param update_period: by default, refresh DHT information once in this many seconds - :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests - :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1) - :param max_backoff: limit maximal sleep time between retries to this value - :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds - :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht - :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time - :param allowed_servers: if defined, send requests only to these servers - :param start: start the background thread (see the note below). If false, you will need to start it manually. :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid running redundant sequence managers for the same set of layers. - """ def __init__( self, - dht: DHT, + config: SequenceManagerConfig, block_uids: Sequence[ModuleUID], - p2p: P2P, - update_period: float = 30, - request_timeout: float = 30, - max_retries: Optional[int] = None, - min_backoff: float = 1, - max_backoff: float = 15 * 60, - ban_timeout: float = 15, - sequence_info: Optional[RemoteSequenceInfo] = None, - rpc_info: Optional[dict] = None, - allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None, - banned_peers: Optional[Blacklist] = None, - # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below) + *, + dht: Optional[DHT] = None, + state: Optional[SequenceManagerState] = None, ): assert len(block_uids) > 0, "Sequences must contain at least one block" - self.dht, self.p2p = dht, p2p - self.request_timeout, self.max_retries = request_timeout, max_retries - self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff + + self.config = config + if state is None: + state = SequenceManagerState() + self.state = state + + if dht is None: + dht = DHT( + initial_peers=config.initial_peers, + client_mode=True, + num_workers=config.n_layer, + startup_timeout=config.daemon_startup_timeout, + start=True, + ) + assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance" + self.dht = dht + + if state.p2p is None: + state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) + self.lock_changes = threading.Lock() - self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update)) + self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update)) self._thread_start_lock = threading.Lock() self.policy = NoSpendingPolicy() - self._rpc_info = rpc_info - if allowed_servers is not None: - allowed_servers = { - PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers - } - self.allowed_servers = allowed_servers - self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers - - if sequence_info is None: - self.sequence_info = RemoteSequenceInfo.make_empty(block_uids) + if state.banned_peers is None: + state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0) + if state.sequence_info is None: + state.sequence_info = RemoteSequenceInfo.make_empty(block_uids) + if state.sequence_info.last_updated_time is None: # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records # in the first _update() instead of the latest ones. This makes the first .update() faster. petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True) self._need_latest_infos = False else: - self.sequence_info = sequence_info - assert block_uids == sequence_info.block_uids + assert block_uids == state.sequence_info.block_uids self._thread.ready.set() # no need to await the first dht fetch self._need_latest_infos = True @@ -118,7 +134,7 @@ class RemoteSequenceManager: span_sequence = [] current_index = start_index while current_index < end_index: - candidate_spans = self.sequence_info.spans_containing_block[current_index] + candidate_spans = self.state.sequence_info.spans_containing_block[current_index] if not candidate_spans: raise MissingBlocksError(current_index) if mode == "random": @@ -143,86 +159,62 @@ class RemoteSequenceManager: assert isinstance(ix, (int, slice)) if not isinstance(ix, slice): ix = slice(int(ix), int(ix) + 1, 1) - return type(self)( - self.dht, - self.block_uids[ix], - self.p2p, - update_period=self._thread.update_period, - request_timeout=self.request_timeout, - ban_timeout=self.ban_timeout, - min_backoff=self.min_backoff, - max_backoff=self.max_backoff, - sequence_info=self.sequence_info[ix], - rpc_info=self._rpc_info, - allowed_servers=self.allowed_servers, - banned_peers=self.banned_peers, - ) + return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix]) def update(self, *, wait: bool): """Run an asynchronous update in background as soon as possible""" - self.ready.clear() # TODO this should be a separate event + self.ready.clear() self._thread.trigger.set() if wait: self.ready.wait() def _update(self): """Perform an immediate and synchronous refresh, may take time""" - for attempt_no in itertools.count(): - try: - new_block_infos = petals.dht_utils.get_remote_module_infos( - self.dht, self.block_uids, latest=self._need_latest_infos - ) - self._need_latest_infos = True # All future _update() should use latest infos - - for block_info in new_block_infos: - if not block_info: - continue - - # Apply whitelist, if defined - if self.allowed_servers is not None: - block_info.servers = { - peer_id: server_info - for peer_id, server_info in block_info.servers.items() - if peer_id in self.allowed_servers - } - - # Remove temporarily banned peers, unless there are no peers left - valid_servers = { - peer_id: server_info - for peer_id, server_info in block_info.servers.items() - if peer_id not in self.banned_peers - } - if len(valid_servers) < len(block_info.servers): - if valid_servers: - logger.debug( - f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}" - ) - block_info.servers = valid_servers - else: - # If we blacklisted all servers, the error may actually be client-caused - logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist") - - with self.lock_changes: - self.sequence_info.update_(new_block_infos) - missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]] - if missing_blocks: - raise MissingBlocksError(missing_blocks) - self.ready.set() # if there is an active server for every block, we may begin running - break + new_block_infos = petals.dht_utils.get_remote_module_infos( + self.dht, self.block_uids, latest=self._need_latest_infos + ) + self._need_latest_infos = True # All future _update() should use latest infos + + for block_info in new_block_infos: + if not block_info: + continue + + # Apply whitelist, if defined + if self.config.allowed_servers is not None: + block_info.servers = { + peer_id: server_info + for peer_id, server_info in block_info.servers.items() + if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers + } + + # Remove temporarily banned peers, unless there are no peers left + valid_servers = { + peer_id: server_info + for peer_id, server_info in block_info.servers.items() + if peer_id not in self.state.banned_peers + } + if len(valid_servers) < len(block_info.servers): + if valid_servers: + logger.debug( + f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}" + ) + block_info.servers = valid_servers + else: + # If we blacklisted all servers, the error may actually be client-caused + logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist") - except Exception as e: - delay = self.get_retry_delay(attempt_no) - logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)") - maybe_log_traceback(e) - time.sleep(delay) + with self.lock_changes: + self.state.sequence_info.update_(new_block_infos) + self.ready.set() - def on_request_failure(self, peer_id: PeerID): + def on_request_failure(self, peer_id: Optional[PeerID]): """remove a given peer from the routing table. If the routing is no longer possible, trigger an update""" - logger.info(f"Peer {peer_id} did not respond, banning it temporarily") - self.banned_peers.register_failure(peer_id) + if peer_id is not None: + logger.debug(f"Peer {peer_id} did not respond, banning it temporarily") + self.state.banned_peers.register_failure(peer_id) with self.lock_changes: should_update = False - for info in self.sequence_info.block_infos: + for info in self.state.sequence_info.block_infos: info.servers.pop(peer_id, None) if not info.servers: should_update = True @@ -232,7 +224,7 @@ class RemoteSequenceManager: def on_request_success(self, peer_id: PeerID): """if peer has a failure streak, clear that streak""" - self.banned_peers.register_success(peer_id) + self.state.banned_peers.register_success(peer_id) def __len__(self): return len(self.block_uids) @@ -247,57 +239,58 @@ class RemoteSequenceManager: @property def block_uids(self): - return self.sequence_info.block_uids + return self.state.sequence_info.block_uids @property def rpc_info(self): """Return the rpc_info queried from one of the servers that hold the first block""" - if self._rpc_info is None: - with self._thread_start_lock: - if not self.is_alive(): - self._thread.start() - - for attempt_no in itertools.count(): - peer_id = None - try: - if not self.ready.is_set(): - self.update(wait=True) - - active_servers = [ - peer_id - for peer_id, server in self.sequence_info.block_infos[0].servers.items() - if server.state == ServerState.ONLINE - ] - if not active_servers: - raise MissingBlocksError(0) - peer_id = random.choice(active_servers) - - stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id) - outputs = RemoteExpertWorker.run_coroutine( - stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0])) - ) - self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info) - self.on_request_success(peer_id) - break - except Exception as e: - if peer_id is not None and not isinstance(e, P2PHandlerError): - self.on_request_failure(peer_id) - if attempt_no + 1 == self.max_retries: - raise - delay = self.get_retry_delay(attempt_no) - logger.warning( - f"Caught exception when gathering information from peer {peer_id} " - f"(retry in {delay:.0f} sec): {repr(e)}" - ) - maybe_log_traceback(e) - time.sleep(delay) + if self.state.rpc_info is not None: + return self.state.rpc_info + + with self._thread_start_lock: + if not self.is_alive(): + self._thread.start() + + for attempt_no in itertools.count(): + peer_id = None + try: + if not self.ready.is_set(): + self.update(wait=True) + + active_servers = [ + peer_id + for peer_id, server in self.state.sequence_info.block_infos[0].servers.items() + if server.state == ServerState.ONLINE + ] + if not active_servers: + raise MissingBlocksError(0) + peer_id = random.choice(active_servers) + + stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id) + outputs = RemoteExpertWorker.run_coroutine( + stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout) + ) + self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info) + self.on_request_success(peer_id) + break + except Exception as e: + self.on_request_failure(peer_id) + if attempt_no + 1 == self.config.max_retries: + raise + delay = self.get_retry_delay(attempt_no) + logger.warning( + f"Caught exception when gathering information from peer {peer_id} " + f"(retry in {delay:.0f} sec): {repr(e)}" + ) + maybe_log_traceback(e) + time.sleep(delay) - return self._rpc_info + return self.state.rpc_info def get_retry_delay(self, attempt_no: int) -> float: if attempt_no == 0: return 0 - return min(self.min_backoff * 2 ** (attempt_no - 1), self.max_backoff) + return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff) def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]: """ diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index b846dfc..166b93c 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -67,7 +67,7 @@ async def sequential_forward( span = sequences.popleft() - stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) + stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) inputs_and_prompts = [inputs, prompts[span.start : span.end]] span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) @@ -77,7 +77,7 @@ async def sequential_forward( stub, sequence_manager.rpc_info, *inputs_and_prompts, - timeout=sequence_manager.request_timeout, + timeout=sequence_manager.config.request_timeout, metadata=MSGPackSerializer.dumps(metadata), ) @@ -93,9 +93,8 @@ async def sequential_forward( sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None: - sequence_manager.on_request_failure(span.peer_id) - if attempt_no + 1 == sequence_manager.max_retries: + sequence_manager.on_request_failure(span.peer_id if span is not None else None) + if attempt_no + 1 == sequence_manager.config.max_retries: raise delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( @@ -152,7 +151,7 @@ async def sequential_backward( span = forward_sequences.pop() span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) - stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) + stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) metadata = sequence_manager.get_request_metadata( "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id ) @@ -163,7 +162,7 @@ async def sequential_backward( inputs, grad_outputs, prompts[span.start : span.end], - timeout=sequence_manager.request_timeout, + timeout=sequence_manager.config.request_timeout, metadata=MSGPackSerializer.dumps(metadata), ) grad_outputs = [grad_outputs] @@ -171,9 +170,8 @@ async def sequential_backward( sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None: - sequence_manager.on_request_failure(span.peer_id) - if attempt_no + 1 == sequence_manager.max_retries: + sequence_manager.on_request_failure(span.peer_id if span is not None else None) + if attempt_no + 1 == sequence_manager.config.max_retries: raise delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index 06c30eb..69cd64f 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -71,67 +71,6 @@ async def _declare_active_modules( ) -def get_remote_sequence( - dht: DHT, - start: int, - stop: int, - config: petals.client.DistributedBloomConfig, - dht_prefix: Optional[str] = None, - return_future: bool = False, -) -> Union[petals.client.RemoteSequential, MPFuture]: - return RemoteExpertWorker.run_coroutine( - _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future - ) - - -async def _get_remote_sequence( - dht: DHT, - start: int, - stop: int, - config: petals.client.DistributedBloomConfig, - dht_prefix: Optional[str] = None, -) -> petals.client.RemoteSequential: - uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)] - p2p = await dht.replicate_p2p() - manager = petals.client.RemoteSequenceManager(dht, uids, p2p) - return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager) - - -def get_remote_module( - dht: DHT, - uid_or_uids: Union[ModuleUID, List[ModuleUID]], - config: petals.client.DistributedBloomConfig, - dht_prefix: Optional[str] = None, - return_future: bool = False, -) -> Union[Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]], MPFuture]: - """ - :param uid_or_uids: find one or more modules with these ids from across the DHT - :param config: model config, usually taken by .from_pretrained(MODEL_NAME) - :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background. - :returns: a list of [RemoteTransformerBlock] - """ - return RemoteExpertWorker.run_coroutine( - _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future - ) - - -async def _get_remote_module( - dht: DHT, - uid_or_uids: Union[ModuleUID, List[ModuleUID]], - config: petals.client.DistributedBloomConfig, - dht_prefix: Optional[str] = None, -) -> Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]]: - single_uid = isinstance(uid_or_uids, ModuleUID) - uids = [uid_or_uids] if single_uid else uid_or_uids - p2p = await dht.replicate_p2p() - managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids) - modules = [ - petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) - for m in managers - ] - return modules[0] if single_uid else modules - - def get_remote_module_infos( dht: DHT, uids: Sequence[ModuleUID], diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index d2fbdde..4cddfed 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -1,28 +1,24 @@ import random from typing import Union -import hivemind import pytest import torch from transformers.models.bloom.configuration_bloom import BloomConfig from petals.bloom.block import WrappedBloomBlock from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block -from petals.client import DistributedBloomConfig -from petals.client.remote_sequential import RemoteTransformerBlock +from petals.client import DistributedBloomConfig, RemoteSequential from petals.data_structures import UID_DELIMITER -from petals.dht_utils import get_remote_module from test_utils import * @pytest.mark.forked def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): - dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) + config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + remote_sequential = RemoteSequential(config) for block_index in random.sample(range(config.n_layer), 3): - remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config) - assert isinstance(remote_block, RemoteTransformerBlock) + remote_block = remote_sequential[block_index] inputs = torch.randn(1, 8, config.hidden_size) outputs_forward = remote_block(inputs) @@ -36,7 +32,6 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: sess.step(inputs[:, -1:, :]) assert "Maximum length exceeded" in repr(exc_info.value) - outputs_inference = torch.cat(outputs_inference, dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 9a619b7..15f3b5c 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -4,22 +4,19 @@ # - if you want to figure out chained inference, ask yozh -import hivemind import pytest import torch from petals.bloom.from_pretrained import load_pretrained_block from petals.client import DistributedBloomConfig from petals.client.remote_sequential import RemoteSequential -from petals.dht_utils import get_remote_sequence from test_utils import * @pytest.mark.forked def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1): - dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) - remote_blocks = get_remote_sequence(dht, 3, 6, config) + config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + remote_blocks = RemoteSequential(config, start_block=3, end_block=6) assert isinstance(remote_blocks, RemoteSequential) ref_blocks = [ @@ -46,10 +43,8 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq @pytest.mark.forked def test_chained_inference_exact_match(atol_inference=1e-4): - dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) - remote_blocks = get_remote_sequence(dht, 3, 5, config) - assert isinstance(remote_blocks, RemoteSequential) + config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + remote_blocks = RemoteSequential(config, start_block=3, end_block=5) inputs = torch.randn(1, 8, config.hidden_size) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index a69c42f..d46ca1c 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -20,7 +20,7 @@ def test_remote_sequential(): test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True) grad_proj = torch.randn(1, 5, config.hidden_size) - sequential = RemoteSequential(config, dht) + sequential = RemoteSequential(config, dht=dht) full_outputs = sequential(test_inputs) (full_outputs * grad_proj).sum().backward() @@ -48,7 +48,7 @@ def test_remote_sequential(): # test RemoteSequential with lossy compression block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] lossy_sequential = RemoteSequential( - config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p) + config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht) ) test_inputs.grad = None @@ -85,8 +85,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager): @pytest.mark.forked def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) - dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) - remote_sequential = RemoteSequential(config, dht) + remote_sequential = RemoteSequential(config) inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1) output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1) diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index f0b61cf..7dbc82f 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -18,15 +18,14 @@ logger = get_logger(__name__) def test_sequence_manager_basics(mode: str): config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) - sequential = RemoteSequential(config, dht) + sequential = RemoteSequential(config, dht=dht) shutdown_evt = threading.Event() # test RemoteSequential with lossy compression block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] sequential = RemoteSequential( config, - dht, - sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt), + sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt), ) sequence = sequential.sequence_manager.make_sequence(mode=mode) diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index 54d6d33..0010167 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -4,34 +4,33 @@ import hivemind import pytest import torch -from petals.client import DistributedBloomConfig +from petals.client import DistributedBloomConfig, RemoteSequential from petals.data_structures import UID_DELIMITER -from petals.dht_utils import get_remote_sequence from petals.server.handler import CACHE_TOKENS_AVAILABLE from test_utils import * @pytest.mark.forked def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50): - dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) config = DistributedBloomConfig.from_pretrained(MODEL_NAME) + dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) + blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to) + blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to) - blocks1 = get_remote_sequence(dht, block_from, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}") - blocks2 = get_remote_sequence(dht, block_to - 1, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}") info_before = blocks1.sequence_manager.rpc_info with blocks1.inference_session(max_length=max_length) as sess: sess.step(torch.randn(1, 1, config.hidden_size)) - blocks1.sequence_manager._rpc_info = None # invalidate cache + blocks1.sequence_manager.state.rpc_info = None # invalidate cache info_inside = blocks1.sequence_manager.rpc_info with blocks2.inference_session(max_length=max_length2) as sess2: sess2.step(torch.randn(1, 1, config.hidden_size)) - blocks2.sequence_manager._rpc_info = None # invalidate cache + blocks2.sequence_manager.state.rpc_info = None # invalidate cache info_inside2 = blocks2.sequence_manager.rpc_info time.sleep(0.1) - blocks1.sequence_manager._rpc_info = None # invalidate cache + blocks1.sequence_manager.state.rpc_info = None # invalidate cache info_after = blocks1.sequence_manager.rpc_info assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE] From 0a313bf6c5d82b57103b973f6c851e5186f91cb1 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 7 May 2023 14:57:05 +0400 Subject: [PATCH 081/168] Update hivemind to 1.1.8, enable efficient bfloat16 encoding (#311) This PR: 1. Updates hivemind to 1.1.8 (includes https://github.com/learning-at-home/hivemind/pull/565) 2. Enables efficient bfloat16 serialization by default (`USE_LEGACY_BFLOAT16 = False`) 3. Removes logging code that was included to hivemind in https://github.com/learning-at-home/hivemind/pull/542 --- setup.cfg | 2 +- src/petals/__init__.py | 11 +++++++++++ src/petals/utils/logging.py | 18 ------------------ 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/setup.cfg b/setup.cfg index 786c8f5..8c237aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = huggingface-hub>=0.11.1,<1.0.0 transformers>=4.25.1,<5.0.0 speedtest-cli==2.1.3 - hivemind==1.1.7 + hivemind==1.1.8 tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 373c40a..7a39b49 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,6 +1,17 @@ +import os + +import hivemind + from petals.client import * from petals.utils.logging import initialize_logs as _initialize_logs __version__ = "1.1.4" + +def _override_bfloat16_mode_default(): + if os.getenv("USE_LEGACY_BFLOAT16") is None: + hivemind.compression.base.USE_LEGACY_BFLOAT16 = False + + _initialize_logs() +_override_bfloat16_mode_default() diff --git a/src/petals/utils/logging.py b/src/petals/utils/logging.py index 6fe099f..0574fa0 100644 --- a/src/petals/utils/logging.py +++ b/src/petals/utils/logging.py @@ -4,16 +4,6 @@ import os from hivemind.utils import logging as hm_logging -def in_jupyter() -> bool: - """Check if the code is run in Jupyter or Colab""" - - try: - __IPYTHON__ - return True - except NameError: - return False - - def initialize_logs(): """Initialize Petals logging tweaks. This function is called when you import the `petals` module.""" @@ -21,14 +11,6 @@ def initialize_logs(): if os.getenv("PETALS_LOGGING", "True").lower() in ("false", "0"): return - if in_jupyter(): - os.environ["HIVEMIND_COLORS"] = "True" - importlib.reload(hm_logging) - - # Remove log handlers from previous import of hivemind.utils.logging and extra handlers on Colab - hm_logging.get_logger().handlers.clear() - hm_logging.get_logger("hivemind").handlers.clear() - hm_logging.use_hivemind_log_handler("in_root_logger") # We suppress asyncio error logs by default since they are mostly not relevant for the end user, From 6137b1b4b057458f593e21d08012df47e5c194f4 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 9 May 2023 22:38:20 +0400 Subject: [PATCH 082/168] Replace .make_sequence(..., mode="random") with mode="max_throughput" (#313) We need to sample the next server using its throughput as the weight to actually achieve max throughput for fine-tuning. As an example, imagine a situation where we have 3 servers with throughputs [1000, 500, 1] hosting the same blocks, then compare the uniform and weighted sampling strategies. --- src/petals/client/inference_session.py | 2 +- src/petals/client/routing/sequence_info.py | 6 ++++-- src/petals/client/routing/sequence_manager.py | 16 ++++++++-------- src/petals/client/sequential_autograd.py | 2 +- src/petals/data_structures.py | 7 ++++++- tests/test_sequence_manager.py | 2 +- 6 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 93700f9..15de442 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -253,7 +253,7 @@ class InferenceSession: ) recovery_until = max(recovery_until, update_end) - updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="fastest") + updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency") # make_sequence() could return a longer sequence updated_spans[-1].end = min(updated_spans[-1].end, update_end) updated_sessions = self._enter_server_sessions(updated_spans) diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py index 8dafb6e..b35b02b 100644 --- a/src/petals/client/routing/sequence_info.py +++ b/src/petals/client/routing/sequence_info.py @@ -77,7 +77,9 @@ class RemoteSequenceInfo: if server.state != ServerState.ONLINE: continue if peer_id not in active_spans: - active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id) + active_spans[peer_id] = RemoteSpanInfo( + peer_id=peer_id, start=block_index, end=block_index + 1, throughput=server.throughput + ) else: # peer_id in active_spans active_spans[peer_id].end = block_index + 1 @@ -91,7 +93,7 @@ class RemoteSequenceInfo: closed_spans.append(active_spans.pop(peer_id)) assert not active_spans, f"spans: {active_spans}" - closed_spans.sort(key=lambda span: span.end - span.start, reverse=True) + closed_spans.sort(key=lambda span: span.length, reverse=True) spans_containing_block = tuple(list() for _ in range(len(block_infos))) for span in closed_spans: diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 5f387c4..8ce33f9 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -115,14 +115,14 @@ class RemoteSequenceManager: self._need_latest_infos = True def make_sequence( - self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random" + self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str ) -> List[RemoteSpanInfo]: """ Form a sequence of remote servers that collectively serve all consecutive layers :param start_index: optional index of the first module in a sequence, default = the first of block_uids :param end_index: optional index of the last module (non-inclusive), default = after last of block uids - :param mode: either random or fastest + :param mode: one of ["max_throughput", "min_latency"] """ with self._thread_start_lock: if not self.is_alive(): @@ -137,17 +137,17 @@ class RemoteSequenceManager: candidate_spans = self.state.sequence_info.spans_containing_block[current_index] if not candidate_spans: raise MissingBlocksError(current_index) - if mode == "random": - chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing - elif mode == "fastest": - # note: this too is a heuristic that will be replaced once we integrate fastest wall time routing + + if mode == "max_throughput": + span_weights = np.array([span.throughput for span in candidate_spans], dtype=np.float64) + elif mode == "min_latency": span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64) - chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) else: raise RuntimeError(f"Unexpected mode {mode}") + chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end - span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id)) + span_sequence.append(dataclasses.replace(chosen_span, start=current_index)) current_index = chosen_span.end route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 166b93c..1c66a49 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -60,7 +60,7 @@ async def sequential_forward( span = None try: if not sequences or attempt_no >= 1: - sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="random")) + sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="max_throughput")) # make_sequence() could return a longer sequence sequences[-1].end = min(sequences[-1].end, end_index) logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers") diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 5d85f07..80b8f62 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -39,9 +39,14 @@ class RemoteModuleInfo: class RemoteSpanInfo: """A chain of remote blocks served by one specific remote peer""" + peer_id: PeerID start: int end: int - peer_id: PeerID + throughput: float + + @property + def length(self): + return self.end - self.start RPCInfo = Dict[str, Any] diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 7dbc82f..38e9a8a 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -14,7 +14,7 @@ logger = get_logger(__name__) @pytest.mark.forked -@pytest.mark.parametrize("mode", ["fastest", "random"]) +@pytest.mark.parametrize("mode", ["max_throughput", "min_latency"]) def test_sequence_manager_basics(mode: str): config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) From d9e7bfc949c3a55d94f5e68a9bc677145edf0547 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 9 May 2023 23:23:08 +0400 Subject: [PATCH 083/168] Divide compute throughput by average no. of used blocks (#314) See #192. --- src/petals/server/server.py | 5 ++-- src/petals/server/throughput.py | 43 +++++++++++++++++++++------------ 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 4f2a645..470ac5f 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -27,7 +27,7 @@ from petals.server.block_utils import get_block_size from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability -from petals.server.throughput import get_dtype_name, get_host_throughput +from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.convert_block import check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR @@ -193,10 +193,11 @@ class Server: assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: - throughput = get_host_throughput( + throughput = get_server_throughput( self.block_config, device, torch_dtype, + num_blocks=num_blocks, load_in_8bit=load_in_8bit, tensor_parallel_devices=self.tensor_parallel_devices, force_eval=(throughput == "eval"), diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index ac43759..a60a24d 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -1,11 +1,12 @@ import fcntl import json +import math import os import time from collections import Counter from hashlib import sha256 from pathlib import Path -from typing import Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Union import torch from hivemind.utils.logging import get_logger @@ -32,11 +33,12 @@ if not hasattr(speedtest, "Speedtest"): ) -def get_host_throughput( +def get_server_throughput( config: BloomConfig, device: torch.device, dtype: Union[str, torch.dtype], *, + num_blocks: int, load_in_8bit: bool, tensor_parallel_devices: Sequence[torch.device], force_eval: bool = False, @@ -47,7 +49,7 @@ def get_host_throughput( if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR lock_path = Path(cache_dir, "throughput.lock") - cache_path = Path(cache_dir, "throughput_v2.json") + cache_path = Path(cache_dir, "throughput_v3.json") # We use the system-wide lock since only one process at a time can measure the host throughput os.makedirs(lock_path.parent, exist_ok=True) @@ -85,7 +87,16 @@ def get_host_throughput( except Exception: logger.exception(f"Failed to save throughput info in {cache_path}") - return cache[cache_key] + throughput_info = cache[cache_key] + + # Most requests start at some block hosted by a server, then use all next blocks hosted on this server. + # Assuming the start block index is distributed uniformly, the average number of blocks used per request is + # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2 + average_blocks_used = (num_blocks + 1) / 2 + throughput = throughput_info["compute_rps"] / average_blocks_used + throughput = min(throughput, throughput_info.get("network_rps", math.inf)) + logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks") + return throughput def measure_throughput_info( @@ -95,22 +106,24 @@ def measure_throughput_info( *, load_in_8bit: bool, tensor_parallel_devices: Sequence[torch.device], -) -> float: +) -> Dict[str, float]: """Measure network and compute throughput in forward pass tokens per second""" logger.info( "Measuring network and compute throughput. This takes about a minute and will be cached for future runs" ) - result = measure_compute_rps( - config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices - ) + throughput_info = { + "compute_rps": measure_compute_rps( + config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices + ) + } try: - result = min(result, measure_network_rps(config)) + throughput_info["network_rps"] = measure_network_rps(config) except Exception: logger.warning("Failed to measure network throughput:", exc_info=True) logger.warning("Proceeding with the compute throughput only") - return result + return throughput_info def measure_network_rps(config: BloomConfig) -> Optional[float]: @@ -127,10 +140,9 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]: raise ValueError("speedtest has returned network_rps == 0") logger.info( - f"Network throughput: " - f"{network_info['download'] / 1e6:.2f} Mbit/s on download, " - f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, " - f"{network_rps:.1f} RPS" + f"Network throughput: {network_rps:.1f} RPS " + f"({network_info['download'] / 1e6:.2f} Mbit/s on download, " + f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)" ) return network_rps @@ -168,7 +180,8 @@ def measure_compute_rps( devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common()) logger.info( - f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS" + f"Forward pass throughput: {device_rps:.1f} RPS per block " + f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})" ) return device_rps From 6eb306a60571c8e6705d8ba2523971681802f2fd Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 9 May 2023 23:51:57 +0400 Subject: [PATCH 084/168] Raise error for unexpected .generate() kwargs (#315) Now, if a user passes unexpected kwargs to `.generate()`, they are __ignored__ and the code continues working as if the argument was correctly supported. For example, people often tried passing `repetition_penalty` and didn't notice that it does not have any effect. This PR fixes this problem. --- src/petals/client/remote_generation.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index bcf59ab..9c2b51f 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -44,6 +44,7 @@ class RemoteGenerationMixin: def generate( self, inputs: Optional[torch.Tensor] = None, + *, do_sample: Optional[bool] = None, temperature: float = 1.0, top_k: Optional[int] = None, @@ -57,9 +58,7 @@ class RemoteGenerationMixin: decoding_algorithm: Optional[DecodingAlgorithm] = None, provided_constraints: List[ABCBloomConstraint] = [], num_return_sequences: Optional[int] = None, - *, session: Optional[InferenceSession] = None, - **model_kwargs, ) -> torch.LongTensor: """ Generates sequences of token ids for models with a language modeling head. @@ -77,19 +76,9 @@ class RemoteGenerationMixin: :param max_new_tokens: The maximum number of tokens to generate. :param decoding_algorithm: The decoding algorithm to use. :param provided_constraints: A list of constraints to use. - :param model_kwargs: Additional arguments to pass to the model. :param num_return_sequences: How many hypothesis from the beam will be in output. """ - assert ( - model_kwargs.get("logits_processor", None) is None - ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor" - assert ( - model_kwargs.get("logits_wrapper", None) is None - ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper" - assert ( - model_kwargs.get("stopping_criteria", None) is None - ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria" prefix_length = 0 if inputs is None else inputs.size(1) prefix_length += self.config.pre_seq_len @@ -226,7 +215,6 @@ class RemoteGenerationMixin: pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, provided_constraints: List[ABCBloomConstraint] = [], - **model_kwargs, ) -> torch.LongTensor: """ Generates sequences of token ids for models with a language modeling head. Uses greedy search. @@ -244,7 +232,6 @@ class RemoteGenerationMixin: eos_token_id=eos_token_id, decoding_algorithm=GreedyAlgorithm(), provided_constraints=provided_constraints, - **model_kwargs, ) def sample( @@ -257,7 +244,6 @@ class RemoteGenerationMixin: pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, provided_constraints: List[ABCBloomConstraint] = [], - **model_kwargs, ) -> torch.LongTensor: """ Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling. @@ -271,7 +257,6 @@ class RemoteGenerationMixin: :param: pad_token_id: The id of the padding token. :param: eos_token_id: The id of the end of sentence token. :param: provided_constraints: A list of constraints to use. - :param: model_kwargs: Additional kwargs to pass to the model. """ return self.generate( @@ -281,7 +266,6 @@ class RemoteGenerationMixin: eos_token_id=eos_token_id, decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p), provided_constraints=provided_constraints, - **model_kwargs, ) def beam_search( @@ -292,7 +276,6 @@ class RemoteGenerationMixin: pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, provided_constraints: List[ABCBloomConstraint] = [], - **model_kwargs, ) -> torch.LongTensor: """ Generates sequences of token ids for models with a language modeling head. Uses beam search. @@ -303,7 +286,6 @@ class RemoteGenerationMixin: :param pad_token_id: The id of the padding token. :param eos_token_id: The id of the end of sentence token. :param provided_constraints: A list of constraints to use. - :param: model_kwargs: Additional kwargs to pass to the model. """ decoding_algorithm = BeamSearchAlgorithm( num_beams=num_beams, @@ -317,7 +299,6 @@ class RemoteGenerationMixin: eos_token_id=eos_token_id, decoding_algorithm=decoding_algorithm, provided_constraints=provided_constraints, - **model_kwargs, ) def beam_sample( @@ -327,7 +308,6 @@ class RemoteGenerationMixin: pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, provided_constraints: List[ABCBloomConstraint] = [], - **model_kwargs, ) -> torch.LongTensor: raise NotImplementedError @@ -338,7 +318,6 @@ class RemoteGenerationMixin: pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, provided_constraints: List[ABCBloomConstraint] = [], - **model_kwargs, ) -> torch.LongTensor: raise NotImplementedError From e02695233877627dad7237c6dc2727f8d289e632 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 10 May 2023 01:53:31 +0400 Subject: [PATCH 085/168] Abort speedtest if it runs too long (#316) Addresses #192 and, specifically, #280. --- src/petals/server/throughput.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index a60a24d..dbefb35 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -1,6 +1,7 @@ import fcntl import json import math +import multiprocessing as mp import os import time from collections import Counter @@ -120,24 +121,26 @@ def measure_throughput_info( } try: throughput_info["network_rps"] = measure_network_rps(config) - except Exception: - logger.warning("Failed to measure network throughput:", exc_info=True) + except Exception as e: + logger.warning(f"Failed to measure network throughput: {repr(e)}") logger.warning("Proceeding with the compute throughput only") return throughput_info -def measure_network_rps(config: BloomConfig) -> Optional[float]: - s = speedtest.Speedtest() - s.get_servers() - s.get_best_server() - s.download() - s.upload() - network_info = s.results.dict() +def measure_network_rps(config: BloomConfig, *, timeout: float = 60) -> Optional[float]: + pipe_recv, pipe_send = mp.Pipe(duplex=False) + process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,)) + process.start() + + if not pipe_recv.poll(timeout): + process.terminate() + raise RuntimeError(f"speedtest did not finish in {timeout} seconds") + network_info = pipe_recv.recv() bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request if network_rps == 0: - raise ValueError("speedtest has returned network_rps == 0") + raise RuntimeError("speedtest has returned network_rps == 0") logger.info( f"Network throughput: {network_rps:.1f} RPS " @@ -147,6 +150,15 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]: return network_rps +def _measure_bits_per_second(pipe_send: mp.Pipe): + s = speedtest.Speedtest() + s.get_servers() + s.get_best_server() + s.download() + s.upload() + pipe_send.send(s.results.dict()) + + def measure_compute_rps( config: BloomConfig, device: torch.device, From 675bacb592bac7145d38ded2ea746da2b9b6c391 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 10 May 2023 03:01:01 +0400 Subject: [PATCH 086/168] Bump version to 1.1.5 (#312) --- src/petals/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 7a39b49..b50b251 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -5,7 +5,7 @@ import hivemind from petals.client import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.1.4" +__version__ = "1.1.5" def _override_bfloat16_mode_default(): From 3e7ae5116de42cd3603647d0727c3324c50f8bd6 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 11 Jun 2023 00:44:41 +0300 Subject: [PATCH 087/168] Remove unused imports and attributes (#324) * Remove unused imports and attributes --- src/petals/client/inference_session.py | 1 - src/petals/client/remote_model.py | 3 +-- src/petals/client/remote_sequential.py | 1 - src/petals/client/routing/sequence_manager.py | 2 +- src/petals/client/sequential_autograd.py | 1 - src/petals/dht_utils.py | 2 -- src/petals/server/backend.py | 2 +- src/petals/server/memory_cache.py | 4 ++-- src/petals/server/reachability.py | 1 - src/petals/server/server.py | 5 ----- src/petals/utils/generation_algorithms.py | 1 - src/petals/utils/logging.py | 1 - tests/test_block_exact_match.py | 1 - tests/test_server_stats.py | 1 - 14 files changed, 5 insertions(+), 21 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 15de442..168dd40 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio import itertools -import logging import time from typing import AsyncIterator, List, Optional diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 0d218d1..b556714 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -1,6 +1,5 @@ -import os from contextlib import contextmanager -from typing import Collection, List, Optional, Union +from typing import List, Optional, Union import hivemind import torch diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 8bc60ff..39811e3 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -4,7 +4,6 @@ from typing import Optional, Union import torch from hivemind import DHT, get_logger -from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from torch import nn import petals.client diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 8ce33f9..6ac7bb0 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -11,7 +11,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Union from weakref import WeakMethod import numpy as np -from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time +from hivemind import DHT, P2P, MSGPackSerializer, PeerID from hivemind.dht.node import Blacklist from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.proto import runtime_pb2 diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 1c66a49..425fdb7 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -3,7 +3,6 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s """ import asyncio import itertools -import logging from collections import deque from typing import List, Optional, Sequence, Tuple diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index 69cd64f..177b2f6 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -8,11 +8,9 @@ from functools import partial from typing import Dict, List, Optional, Sequence, Union from hivemind.dht import DHT, DHTNode, DHTValue -from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import PeerID from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger -import petals.client from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState logger = get_logger(__name__) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 4464e7c..aae181e 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -16,7 +16,7 @@ from transformers import BloomConfig from transformers.models.bloom.modeling_bloom import BloomAttention from petals.data_structures import InferenceMetadata -from petals.server.memory_cache import Handle, MemoryCache +from petals.server.memory_cache import MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import is_dummy diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py index 7ea981f..7f00bae 100644 --- a/src/petals/server/memory_cache.py +++ b/src/petals/server/memory_cache.py @@ -10,7 +10,7 @@ import ctypes import multiprocessing as mp import os import time -from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple +from typing import AsyncContextManager, Dict, Optional, Sequence import hivemind import torch @@ -29,7 +29,7 @@ class MemoryCache: def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float): self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1) self.alloc_timeout = alloc_timeout - self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event() + self._lock_metadata = mp.Lock() self._current_size = mp.Value(ctypes.c_int64, 0, lock=False) self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False) self._allocated_tensors: Dict[Handle, torch.Tensor] = {} diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py index 58caa93..03e01fc 100644 --- a/src/petals/server/reachability.py +++ b/src/petals/server/reachability.py @@ -5,7 +5,6 @@ import time from concurrent.futures import Future from contextlib import asynccontextmanager from functools import partial -from secrets import token_hex from typing import Optional import requests diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 470ac5f..1055d27 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -8,7 +8,6 @@ import threading import time from typing import Dict, List, Optional, Sequence, Union -import numpy as np import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file @@ -502,7 +501,6 @@ class ModuleContainer(threading.Thread): expiration=expiration, daemon=True, ) - self.checkpoint_saver = None # no need to save checkpoints since we do not change model state if start: self.run_in_background(await_ready=True) @@ -517,9 +515,6 @@ class ModuleContainer(threading.Thread): self.online_announcer.start() - if self.checkpoint_saver is not None: - self.checkpoint_saver.start() - for handler in self.conn_handlers: handler.run_in_background() diff --git a/src/petals/utils/generation_algorithms.py b/src/petals/utils/generation_algorithms.py index 9033371..d58f073 100644 --- a/src/petals/utils/generation_algorithms.py +++ b/src/petals/utils/generation_algorithms.py @@ -85,7 +85,6 @@ class NucleusAlgorithm(SamplingAlgorithm): class BeamSearchAlgorithm(DecodingAlgorithm): def __init__(self, num_beams: int, batch_size: int) -> None: self.num_beams = num_beams - self._cur_num_beams = 1 self.batch_size = batch_size self._batch_beams = [list() for _ in range(batch_size)] diff --git a/src/petals/utils/logging.py b/src/petals/utils/logging.py index 0574fa0..919092c 100644 --- a/src/petals/utils/logging.py +++ b/src/petals/utils/logging.py @@ -1,4 +1,3 @@ -import importlib import os from hivemind.utils import logging as hm_logging diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index 4cddfed..a05387d 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -8,7 +8,6 @@ from transformers.models.bloom.configuration_bloom import BloomConfig from petals.bloom.block import WrappedBloomBlock from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block from petals.client import DistributedBloomConfig, RemoteSequential -from petals.data_structures import UID_DELIMITER from test_utils import * diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index 0010167..11d2565 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -5,7 +5,6 @@ import pytest import torch from petals.client import DistributedBloomConfig, RemoteSequential -from petals.data_structures import UID_DELIMITER from petals.server.handler import CACHE_TOKENS_AVAILABLE from test_utils import * From c839173e571de84032194db2c89edf23a34c0504 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 16 Jun 2023 11:52:51 +0200 Subject: [PATCH 088/168] Determine block dtype in a unified manner (#325) * Extract backend_dtype, remove duplicate DTYPE_MAP * Use bfloat16 as the default dtype, resolve dtype in load_pretrained_block --- src/petals/bloom/from_pretrained.py | 5 +++-- src/petals/cli/convert_model.py | 4 +--- src/petals/server/block_utils.py | 13 ++++++------- src/petals/server/server.py | 12 ++++++------ tests/test_dtype.py | 17 +++++++++++++++++ 5 files changed, 33 insertions(+), 18 deletions(-) create mode 100644 tests/test_dtype.py diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index 4748b41..d40b01f 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -21,7 +21,7 @@ from transformers.models.bloom.configuration_bloom import BloomConfig from transformers.utils import get_file_from_repo from petals.bloom.block import WrappedBloomBlock -from petals.server.block_utils import get_block_size +from petals.server.block_utils import get_block_size, resolve_block_dtype from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for logger = get_logger(__name__) @@ -41,6 +41,7 @@ def load_pretrained_block( ) -> WrappedBloomBlock: """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it.""" assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" + torch_dtype = resolve_block_dtype(config, torch_dtype) if config is None: config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) @@ -66,7 +67,7 @@ def load_pretrained_block( for param_name, _ in block.named_parameters(): assert param_name in state_dict, f"{param_name} not in state dict" param = state_dict[param_name] - if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): param = param.to(torch_dtype) set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index 95b08e4..15e12b7 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -10,13 +10,11 @@ from huggingface_hub import HfApi, Repository from tqdm.auto import tqdm from transformers.models.bloom.modeling_bloom import BloomModel -from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH +from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP from petals.client import DistributedBloomConfig logger = get_logger(__name__) -DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") - def main(): parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.") diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index fd39ad6..8d59d18 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -7,14 +7,13 @@ from transformers import BloomConfig from petals.bloom.block import WrappedBloomBlock -def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]: +def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> torch.dtype: """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise.""" - - if dtype == "auto" or dtype is None: - dtype = config.torch_dtype - if dtype == "auto" or dtype is None: - dtype = torch.float32 - return dtype + if dtype not in ("auto", None): + return dtype + if config.torch_dtype not in ("auto", None): + return config.torch_dtype + return torch.bfloat16 def get_block_size( diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 1055d27..2d666ae 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -22,7 +22,7 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace -from petals.server.block_utils import get_block_size +from petals.server.block_utils import get_block_size, resolve_block_dtype from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability @@ -151,7 +151,7 @@ class Server: if isinstance(torch_dtype, str): torch_dtype = DTYPE_MAP[torch_dtype] assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - self.torch_dtype = torch_dtype + self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype) if tensor_parallel_devices is None: tensor_parallel_devices = (device,) @@ -182,6 +182,7 @@ class Server: if attn_cache_size is None: # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336 + self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB") @@ -404,22 +405,21 @@ class ModuleContainer(threading.Thread): ) block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True) - backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype blocks[module_uid] = TransformerBackend( module_uid, block, config=block_config, memory_cache=memory_cache, - backend_dtype=backend_dtype, + backend_dtype=torch_dtype, args_schema=( BatchTensorDescriptor( - 1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression + 1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression ), ), kwargs_schema={}, outputs_schema=( BatchTensorDescriptor( - 1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression + 1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression ), ), min_batch_size=min_batch_size, diff --git a/tests/test_dtype.py b/tests/test_dtype.py new file mode 100644 index 0000000..03afd83 --- /dev/null +++ b/tests/test_dtype.py @@ -0,0 +1,17 @@ +import pytest +import torch + +from petals.bloom.from_pretrained import load_pretrained_block +from petals.client import DistributedBloomConfig +from petals.server.block_utils import resolve_block_dtype +from test_utils import MODEL_NAME + + +@pytest.mark.forked +@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"]) +def test_backend_dtype(torch_dtype): + config = DistributedBloomConfig.from_pretrained(MODEL_NAME) + block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype=torch_dtype) + backend_dtype = resolve_block_dtype(config, torch_dtype) + other_backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype + assert backend_dtype == other_backend_dtype From 5c0733711ab768eea2734a0234eb0e2fade324b7 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 17 Jun 2023 13:14:31 +0200 Subject: [PATCH 089/168] Use number of tokens for attn_cache_size (#286) * Use number of tokens for attn_cache_size * Fix cache_bytes_per_block * Rename attn_cache_size to attn_cache_tokens --- .github/workflows/run-tests.yaml | 2 +- src/petals/cli/run_server.py | 20 +++++--------------- src/petals/server/backend.py | 1 - src/petals/server/server.py | 27 ++++++++++++--------------- 4 files changed, 18 insertions(+), 32 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 3d48d37..37edb8f 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -86,7 +86,7 @@ jobs: python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \ - --torch_dtype float32 --compression NONE --attn_cache_size 0.2GiB &> server1.log & + --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log & SERVER1_PID=$! sleep 5 # wait for the first server to initialize DHT diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 5e7efb5..fb521ef 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -7,7 +7,7 @@ from hivemind.utils.logging import get_logger from humanfriendly import parse_size from petals.constants import PUBLIC_INITIAL_PEERS -from petals.server.server import Server +from petals.server.server import DTYPE_MAP, Server from petals.utils.version import validate_version logger = get_logger(__name__) @@ -78,14 +78,12 @@ def main(): parser.add_argument('--device', type=str, default=None, required=False, help='all blocks will use this device in torch notation; default: cuda if available else cpu') - parser.add_argument("--torch_dtype", type=str, default="auto", + parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", help="Use this dtype to store block weights and do computations. " "By default, respect the dtypes in the pre-trained state dict.") - parser.add_argument('--attn_cache_size', type=str, default=None, - help='The size of GPU memory allocated for storing past attention keys/values between inference steps. ' - 'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. ' - 'Default: 0.5GiB * num_blocks * hidden_size / 14336. ' - 'The latter is the hidden size of the bigscience/bloom-petals model.') + parser.add_argument('--attn_cache_tokens', type=int, default=8192, + help='The number of past attention key/value pairs that will be stored between inference steps. ' + 'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).') parser.add_argument('--alloc_timeout', type=float, default=60, help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed ' 'before rejecting the request') @@ -178,13 +176,6 @@ def main(): compression_type = args.pop("compression").upper() compression = getattr(CompressionType, compression_type) - attn_cache_size = args.pop("attn_cache_size") - if attn_cache_size is not None: - attn_cache_size = parse_size(attn_cache_size) - assert isinstance( - attn_cache_size, (int, type(None)) - ), "Unrecognized value for --attn_cache_size. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)" - max_disk_space = args.pop("max_disk_space") if max_disk_space is not None: max_disk_space = parse_size(max_disk_space) @@ -207,7 +198,6 @@ def main(): announce_maddrs=announce_maddrs, compression=compression, max_disk_space=max_disk_space, - attn_cache_size=attn_cache_size, ) try: server.run() diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index aae181e..76dc52b 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -48,7 +48,6 @@ class TransformerBackend(ModuleBackend): self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward" ) - assert backend_dtype is not None self.dtype = backend_dtype self.shard_num_heads = [] for shard in self.module.module_shards: diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 2d666ae..e424fb5 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -56,7 +56,7 @@ class Server: revision: str = "main", cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, - attn_cache_size: Optional[int] = None, + attn_cache_tokens: int = 8192, alloc_timeout: float = 60, device: Optional[Union[str, torch.device]] = None, compression=CompressionType.NONE, @@ -148,9 +148,7 @@ class Server: device = torch.device(device.type, index=0) self.device = device - if isinstance(torch_dtype, str): - torch_dtype = DTYPE_MAP[torch_dtype] - assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" + torch_dtype = DTYPE_MAP[torch_dtype] self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype) if tensor_parallel_devices is None: @@ -165,6 +163,9 @@ class Server: self.load_in_8bit = load_in_8bit logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format") + max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens + self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8 + assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" if num_blocks is None and block_indices is None: num_blocks = self._choose_num_blocks() @@ -179,13 +180,10 @@ class Server: self.strict_block_indices, self.num_blocks = block_indices, num_blocks gib = 1024**3 - if attn_cache_size is None: - # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly - attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336 - - self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout - logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB") + self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks + logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") + self.alloc_timeout = alloc_timeout if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR self.cache_dir = cache_dir @@ -236,10 +234,9 @@ class Server: # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models gib = 1024**3 - attn_cache_per_block = 0.5 * gib * num_devices # TODO: This does not account for manually set --attn_cache_size autograd_memory = 2 * gib * num_devices # GPU memory used for intermediate tensors in rpc_backward - num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block)) + num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block)) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" logger.info( @@ -256,7 +253,7 @@ class Server: prefix=self.prefix, converted_model_name_or_path=self.converted_model_name_or_path, block_config=self.block_config, - attn_cache_size=self.attn_cache_size, + attn_cache_bytes=self.attn_cache_bytes, alloc_timeout=self.alloc_timeout, throughput=self.throughput, block_indices=block_indices, @@ -356,7 +353,7 @@ class ModuleContainer(threading.Thread): prefix: str, converted_model_name_or_path: str, block_config: BloomConfig, - attn_cache_size: int, + attn_cache_bytes: int, alloc_timeout: float, throughput: float, block_indices: List[int], @@ -390,7 +387,7 @@ class ModuleContainer(threading.Thread): assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices) - memory_cache = MemoryCache(attn_cache_size, alloc_timeout) + memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout) blocks = {} try: for module_uid, block_index in zip(module_uids, block_indices): From cb3f018f9f0362ff4d2aa77c6950c1b6aabcdc43 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 23 Jun 2023 15:46:10 +0400 Subject: [PATCH 090/168] Add LLaMA support (#323) This PR: 1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms. - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ. - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accomodate blocks from different repos (there're a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name). 2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`. 3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.). 4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers. Upgrade instructions: - Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present). --- .github/workflows/run-tests.yaml | 66 +---- setup.cfg | 4 +- src/petals/__init__.py | 12 +- src/petals/bloom/__init__.py | 0 src/petals/bloom/block.py | 62 ---- src/petals/bloom/from_pretrained.py | 132 --------- src/petals/cli/config.json | 20 -- src/petals/cli/convert_model.py | 96 ------- src/petals/cli/inference_one_block.py | 2 +- src/petals/cli/run_server.py | 2 +- src/petals/client/__init__.py | 6 - src/petals/client/from_pretrained.py | 94 ++++++ .../modeling_utils.py => client/lm_head.py} | 72 ++--- src/petals/client/ptune.py | 88 ++++++ src/petals/client/remote_model.py | 268 ------------------ src/petals/client/remote_sequential.py | 7 +- src/petals/client/routing/sequence_manager.py | 9 +- src/petals/models/__init__.py | 2 + src/petals/models/bloom/__init__.py | 7 + src/petals/models/bloom/block.py | 32 +++ src/petals/models/bloom/config.py | 35 +++ src/petals/models/bloom/model.py | 134 +++++++++ src/petals/models/llama/__init__.py | 7 + src/petals/models/llama/block.py | 87 ++++++ src/petals/models/llama/config.py | 35 +++ src/petals/models/llama/model.py | 152 ++++++++++ src/petals/server/backend.py | 21 +- src/petals/server/block_utils.py | 10 +- src/petals/server/from_pretrained.py | 175 ++++++++++++ src/petals/server/server.py | 64 +++-- src/petals/server/throughput.py | 22 +- src/petals/utils/__init__.py | 1 + src/petals/utils/auto_config.py | 23 ++ src/petals/utils/convert_block.py | 28 +- src/petals/utils/disk_cache.py | 8 +- src/petals/utils/version.py | 20 +- tests/test_aux_functions.py | 4 +- tests/test_block_exact_match.py | 70 +---- tests/test_chained_calls.py | 4 +- tests/test_dtype.py | 15 +- tests/test_full_model.py | 4 +- tests/test_remote_sequential.py | 18 +- tests/test_sequence_manager.py | 4 +- tests/test_server_stats.py | 2 +- tests/test_tensor_parallel.py | 2 +- 45 files changed, 1073 insertions(+), 853 deletions(-) delete mode 100644 src/petals/bloom/__init__.py delete mode 100644 src/petals/bloom/block.py delete mode 100644 src/petals/bloom/from_pretrained.py delete mode 100644 src/petals/cli/config.json delete mode 100644 src/petals/cli/convert_model.py create mode 100644 src/petals/client/from_pretrained.py rename src/petals/{bloom/modeling_utils.py => client/lm_head.py} (53%) create mode 100644 src/petals/client/ptune.py delete mode 100644 src/petals/client/remote_model.py create mode 100644 src/petals/models/__init__.py create mode 100644 src/petals/models/bloom/__init__.py create mode 100644 src/petals/models/bloom/block.py create mode 100644 src/petals/models/bloom/config.py create mode 100644 src/petals/models/bloom/model.py create mode 100644 src/petals/models/llama/__init__.py create mode 100644 src/petals/models/llama/block.py create mode 100644 src/petals/models/llama/config.py create mode 100644 src/petals/models/llama/model.py create mode 100644 src/petals/server/from_pretrained.py create mode 100644 src/petals/utils/auto_config.py diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 37edb8f..fbb5b72 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -6,57 +6,8 @@ on: pull_request: jobs: - convert-model: - runs-on: ubuntu-latest - env: - BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }} - timeout-minutes: 15 - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Check if the model is cached - id: cache-model - uses: actions/cache@v3 - with: - path: ~/converted_ok - key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - - name: Set up Python - if: steps.cache-model.outputs.cache-hit != 'true' - uses: actions/setup-python@v3 - with: - python-version: 3.9 - - name: Cache dependencies - if: steps.cache-model.outputs.cache-hit != 'true' - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: Key-v1-3.9-${{ hashFiles('setup.cfg') }} - - name: Install dependencies - if: steps.cache-model.outputs.cache-hit != 'true' - run: | - python -m pip install --upgrade pip - pip install . - - name: Delete any test models older than 1 week - if: steps.cache-model.outputs.cache-hit != 'true' - run: | - python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN - - name: Delete previous version of this model, if exists - if: steps.cache-model.outputs.cache-hit != 'true' - run: | - export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") - python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \ - repo_id='bloom-testing/test-bloomd-560m-$HF_TAG')" || true - - name: Convert model and push to hub - if: steps.cache-model.outputs.cache-hit != 'true' - run: | - export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \ - --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \ - --resize_token_embeddings 50000 && touch ~/converted_ok - run-tests: runs-on: ubuntu-latest - needs: convert-model strategy: matrix: python-version: [ '3.7', '3.8', '3.9', '3.10' ] @@ -80,8 +31,7 @@ jobs: pip install .[dev] - name: Test run: | - export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG + export MODEL_NAME=bigscience/bloom-560m export REF_NAME=bigscience/bloom-560m python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ @@ -104,23 +54,19 @@ jobs: --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log & SERVER3_PID=$! - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \ - --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log & - SERVER4_PID=$! - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \ - --initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu --torch_dtype float32 &> server5.log & - SERVER5_PID=$! + --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log & + SERVER4_PID=$! tail -n 100 -f server*.log & LOGGER_PID=$! sleep 30 # wait for servers to download layers - kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init + kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init pytest tests --durations=0 --durations-min=1.0 -v - kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests + kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests - kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID $LOGGER_PID + kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID echo "Done!" diff --git a/setup.cfg b/setup.cfg index 8c237aa..4722c63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,8 @@ install_requires = bitsandbytes==0.38.0.post2 accelerate>=0.16.0,<1.0.0 huggingface-hub>=0.11.1,<1.0.0 - transformers>=4.25.1,<5.0.0 + tokenizers>=0.13.3 + transformers>=4.30.1,<5.0.0 speedtest-cli==2.1.3 hivemind==1.1.8 tensor_parallel==1.0.23 @@ -43,6 +44,7 @@ install_requires = async-timeout>=4.0.2 cpufeature>=0.2.0 packaging>=20.9 + sentencepiece>=0.1.99 [options.extras_require] dev = diff --git a/src/petals/__init__.py b/src/petals/__init__.py index b50b251..26aa3ab 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,11 +1,21 @@ import os import hivemind +import transformers +from packaging import version from petals.client import * +from petals.models import * +from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.1.5" +__version__ = "1.2.0.dev0" + + +if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): + assert ( + version.parse("4.30.1") <= version.parse(transformers.__version__) < version.parse("5.0.0") + ), "Please install a proper transformers version: pip install transformers>=4.30.1,<5.0.0" def _override_bfloat16_mode_default(): diff --git a/src/petals/bloom/__init__.py b/src/petals/bloom/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/petals/bloom/block.py b/src/petals/bloom/block.py deleted file mode 100644 index 9037ee4..0000000 --- a/src/petals/bloom/block.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Bloom intermediate layer -Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b -See commit history for authorship. -""" -import os -from typing import Optional, Tuple - -import torch.nn.quantized.dynamic.modules.linear -import transformers -from packaging import version -from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor - -if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): - assert ( - version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0") - ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0" - - -class WrappedBloomBlock(BloomBlock): - def forward( - self, - hidden_states: torch.Tensor, - *args, - attention_mask: Optional[torch.Tensor] = None, - alibi: Optional[torch.Tensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs - ): - assert attention_mask is None - batch_size, seq_length = hidden_states.shape[:2] - past_length = 0 if layer_past is None else layer_past[0].shape[-1] - seq_length_with_past = seq_length + past_length - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - if alibi is None: - alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) - attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) - return super().forward( - hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs - ) - - def _prepare_attn_mask( - self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int - ) -> torch.BoolTensor: - # create causal mask - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - if src_length > 1: - combined_attention_mask = _make_causal_mask( - torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py deleted file mode 100644 index d40b01f..0000000 --- a/src/petals/bloom/from_pretrained.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code. -If necessary, one can rewrite this to implement a different behavior, such as: - - loading files from a local data source (e.g. S3) - - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to ) - - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html ) - -""" -from __future__ import annotations - -import itertools -import time -from typing import Optional, OrderedDict, Union - -import torch -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from hivemind.utils.logging import get_logger -from transformers.modeling_utils import WEIGHTS_NAME -from transformers.models.bloom.configuration_bloom import BloomConfig -from transformers.utils import get_file_from_repo - -from petals.bloom.block import WrappedBloomBlock -from petals.server.block_utils import get_block_size, resolve_block_dtype -from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for - -logger = get_logger(__name__) - -CLIENT_BRANCH = "main" -BLOCK_BRANCH_PREFIX = "block_" - - -def load_pretrained_block( - converted_model_name_or_path: str, - block_index: int, - config: Optional[BloomConfig] = None, - torch_dtype: Union[torch.dtype, str] = "auto", - use_auth_token: Optional[str] = None, - cache_dir: Optional[str] = None, - max_disk_space: Optional[int] = None, -) -> WrappedBloomBlock: - """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it.""" - assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - torch_dtype = resolve_block_dtype(config, torch_dtype) - - if config is None: - config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) - if cache_dir is None: - cache_dir = DEFAULT_CACHE_DIR - - with init_empty_weights(): - block = WrappedBloomBlock(config) - - state_dict = _load_state_dict( - converted_model_name_or_path, - block_index, - config, - use_auth_token=use_auth_token, - cache_dir=cache_dir, - max_disk_space=max_disk_space, - ) - - # dummy load, check that keys match - report = block.load_state_dict(state_dict, strict=True) - assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" - - for param_name, _ in block.named_parameters(): - assert param_name in state_dict, f"{param_name} not in state dict" - param = state_dict[param_name] - if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): - param = param.to(torch_dtype) - set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) - - logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}") - return block - - -def _load_state_dict( - pretrained_model_name_or_path: str, - block_index: int, - config: BloomConfig, - *, - use_auth_token: Optional[str] = None, - cache_dir: str, - max_disk_space: Optional[int] = None, - min_backoff: float = 5, -) -> OrderedDict[str, torch.Tensor]: - revision = BLOCK_BRANCH_PREFIX + str(block_index) - - # First, try to find the weights locally - try: - with allow_cache_reads(cache_dir): - archive_file = get_file_from_repo( - pretrained_model_name_or_path, - filename=WEIGHTS_NAME, - revision=revision, - use_auth_token=use_auth_token, - cache_dir=cache_dir, - local_files_only=True, - ) - if archive_file is not None: - return torch.load(archive_file, map_location="cpu") - except Exception: - logger.debug( - f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True - ) - - # If not found, ensure that we have enough disk space to download them (maybe remove something) - for attempt_no in itertools.count(): - try: - with allow_cache_writes(cache_dir): - block_size = get_block_size(config, "disk") - free_disk_space_for( - pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space - ) - - archive_file = get_file_from_repo( - pretrained_model_name_or_path, - filename=WEIGHTS_NAME, - revision=revision, - use_auth_token=use_auth_token, - cache_dir=cache_dir, - local_files_only=False, - ) - return torch.load(archive_file, map_location="cpu") - except Exception as e: - delay = min_backoff * (2**attempt_no) - logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True) - time.sleep(delay) - - -DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/cli/config.json b/src/petals/cli/config.json deleted file mode 100644 index ca7ffbb..0000000 --- a/src/petals/cli/config.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "apply_residual_connection_post_layernorm": false, - "attention_dropout": 0.0, - "attention_softmax_in_fp32": true, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_dropout": 0.0, - "initializer_range": 0.02, - "layer_norm_epsilon": 1e-05, - "masked_softmax_fusion": true, - "model_type": "bloom", - "n_embed": 14336, - "n_layer": 70, - "num_attention_heads": 112, - "pretraining_tp": 4, - "slow_but_exact": false, - "transformers_version": "4.20.0.dev0", - "use_cache": true, - "vocab_size": 250880 -} \ No newline at end of file diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py deleted file mode 100644 index 15e12b7..0000000 --- a/src/petals/cli/convert_model.py +++ /dev/null @@ -1,96 +0,0 @@ -import argparse -import os - -import psutil -import torch.backends.quantized -import torch.nn as nn -import transformers -from hivemind.utils.logging import get_logger -from huggingface_hub import HfApi, Repository -from tqdm.auto import tqdm -from transformers.models.bloom.modeling_bloom import BloomModel - -from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP -from petals.client import DistributedBloomConfig - -logger = get_logger(__name__) - - -def main(): - parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.") - - parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained") - parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub") - parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype") - parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder") - parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo") - parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch") - parser.add_argument( - "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix" - ) - parser.add_argument( - "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts" - ) - parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained") - parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size") - args = parser.parse_args() - - free_ram_gb = psutil.virtual_memory().available / 2**30 - if args.model == "bigscience/bloom" and free_ram_gb < 400: - logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free") - - assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}" - if os.path.exists(args.output_path) and ( - len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path) - ): - raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory") - - logger.info(f"Loading source model {args.model} (this may take a few minutes)") - config = DistributedBloomConfig.from_pretrained( - args.model, use_auth_token=args.use_auth_token, revision=args.revision - ) - config.dht_prefix = args.output_repo - - model = BloomModel.from_pretrained( - args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype] - ) - if args.resize_token_embeddings: - logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}") - model.resize_token_embeddings(args.resize_token_embeddings) - config.vocab_size = args.resize_token_embeddings - - tokenizer = transformers.AutoTokenizer.from_pretrained( - args.model, use_auth_token=args.use_auth_token, revision=args.revision - ) - os.makedirs(args.output_path, exist_ok=True) - - api = HfApi(token=args.use_auth_token) - api.create_repo(args.output_repo, repo_type="model", exist_ok=True) - repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token) - repo.git_pull() - - transformer_blocks = model.h - logger.info( - f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0" - f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}" - ) - for i, block in enumerate(tqdm(transformer_blocks)): - repo.git_checkout(args.client_branch, create_branch_ok=True) - with repo.commit( - commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True - ): - torch.save(block.state_dict(), "./pytorch_model.bin") - - logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}") - repo.git_checkout(args.client_branch, create_branch_ok=True) - with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True): - model.h = nn.ModuleList() - model.save_pretrained(".") - tokenizer.save_pretrained(".") - config.save_pretrained(".") - - logger.info(f"Converted {args.model} and pushed to {args.output_repo}") - - -if __name__ == "__main__": - main() diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py index 01ba1ef..6d53e9b 100644 --- a/src/petals/cli/inference_one_block.py +++ b/src/petals/cli/inference_one_block.py @@ -6,7 +6,7 @@ from tqdm.auto import trange from transformers import BloomConfig from transformers.models.bloom.modeling_bloom import build_alibi_tensor -from petals.bloom.block import BloomBlock +from petals.models.bloom.block import BloomBlock logger = get_logger(__name__) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index fb521ef..4c6f0e5 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -87,7 +87,7 @@ def main(): parser.add_argument('--alloc_timeout', type=float, default=60, help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed ' 'before rejecting the request') - parser.add_argument('--revision', type=str, default='main', + parser.add_argument('--revision', type=str, default=None, help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models" "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") diff --git a/src/petals/client/__init__.py b/src/petals/client/__init__.py index 5ff26bc..f80c4b1 100644 --- a/src/petals/client/__init__.py +++ b/src/petals/client/__init__.py @@ -1,10 +1,4 @@ from petals.client.inference_session import InferenceSession -from petals.client.remote_model import ( - DistributedBloomConfig, - DistributedBloomForCausalLM, - DistributedBloomForSequenceClassification, - DistributedBloomModel, -) from petals.client.remote_sequential import RemoteSequential from petals.client.routing.sequence_manager import RemoteSequenceManager from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase diff --git a/src/petals/client/from_pretrained.py b/src/petals/client/from_pretrained.py new file mode 100644 index 0000000..b8d02c0 --- /dev/null +++ b/src/petals/client/from_pretrained.py @@ -0,0 +1,94 @@ +import contextlib +import json +import os +import re +import tempfile +import threading +from typing import List, Optional, Tuple, Union + +import torch +from hivemind.utils.logging import get_logger +from transformers import BloomPreTrainedModel, modeling_utils + +from petals.utils.version import get_compatible_model_repo + +logger = get_logger(__name__) + + +class FromPretrainedMixin: + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, os.PathLike, None], + *args, + low_cpu_mem_usage: Optional[bool] = None, + torch_dtype: Optional[Union[str, torch.dtype]] = None, + **kwargs, + ): + model_name_or_path = get_compatible_model_repo(model_name_or_path) + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + if torch_dtype is None: + # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast, + # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights. + torch_dtype = "auto" + + with ignore_keys(cls._keys_to_ignore_on_load_unexpected): + return super().from_pretrained( + model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs + ) + + from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( + "low_cpu_mem_usage(`bool`, *optional*)", + "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", + ).replace( + "torch_dtype (`str` or `torch.dtype`, *optional*)", + 'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)', + ) + + +_shard_config = threading.local() +_shard_config.ignored_keys = None + + +@contextlib.contextmanager +def ignore_keys(patterns: List[str]): + try: + prev_patterns = _shard_config.ignored_keys + _shard_config.ignored_keys = patterns + yield + finally: + _shard_config.ignored_keys = prev_patterns + + +def patched_get_checkpoint_shard_files( + pretrained_model_name_or_path, index_filename, *args, **kwargs +) -> Tuple[List[str], dict]: + """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys.""" + + should_ignore_keys = _shard_config.ignored_keys is not None + tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext() + with tempdir_ctx as tempdir: + if should_ignore_keys: + with open(index_filename) as f: + index = json.load(f) + n_original_shards = len(set(index["weight_map"].values())) + + index["weight_map"] = { + param_name: filename + for param_name, filename in index["weight_map"].items() + if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys) + } + n_loaded_shards = len(set(index["weight_map"].values())) + logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}") + + # Replace the original index with a patched JSON, where ignored keys are removed + index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json") + with open(index_filename, "w") as f: + json.dump(index, f) + + return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs) + + +original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files +modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/client/lm_head.py similarity index 53% rename from src/petals/bloom/modeling_utils.py rename to src/petals/client/lm_head.py index eddbb9d..ddd2887 100644 --- a/src/petals/bloom/modeling_utils.py +++ b/src/petals/client/lm_head.py @@ -1,10 +1,6 @@ -""" -PyTorch BLOOM model that implements several memory-efficient modes. -Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b -See commit history for authorship. -""" - +import dataclasses import platform +from typing import Optional, Union import psutil import torch @@ -12,21 +8,30 @@ import torch.nn.functional as F import torch.utils.checkpoint from hivemind import get_logger from torch import nn -from transformers import BloomConfig +from transformers import PretrainedConfig logger = get_logger(__name__) -class LMHead(nn.Module): - """ - The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input - embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. - In addition, it provides an effcient way to deal with half-precision word embeddings on CPU. - """ +@dataclasses.dataclass +class LMHeadConfig: + # This settings matter for running the client with dtype bfloat16 on CPU. + # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations. + use_chunked_forward: Union[str, bool] = "auto" + chunked_forward_step: int = 16384 + - def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding): +class LMHead(nn.Module): + def __init__(self, config: PretrainedConfig): super().__init__() - self.word_embeddings = word_embeddings + + if not config.tie_word_embeddings: + self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False)) + else: + self.weight = None # Will be set to get_input_embeddings().weight during loading the model + self.bias = None + self.in_features = config.hidden_size # Similar to nn.Linear attributes + self.out_features = config.vocab_size self.use_chunked_forward = config.use_chunked_forward if self.use_chunked_forward == "auto": @@ -42,35 +47,17 @@ class LMHead(nn.Module): self.chunked_forward_step = config.chunked_forward_step self._bf16_warning_shown = False - @property - def in_features(self) -> int: - return self.word_embeddings.num_embeddings - - @property - def out_features(self) -> int: - return self.word_embeddings.embedding_dim - - @property - def weight(self): - return self.word_embeddings.weight - - @property - def bias(self): - return None - def forward(self, hidden_states): - word_embeddings = self.word_embeddings.weight - if ( - word_embeddings.dtype in [torch.float16, torch.bfloat16] - and word_embeddings.device.type == "cpu" + self.weight.dtype in [torch.float16, torch.bfloat16] + and self.weight.device.type == "cpu" and self.use_chunked_forward ): lm_logits = self.chunked_forward(hidden_states) else: # Switch dtype in case word_embeddings are fp16/bf16 - hidden_states = hidden_states.to(word_embeddings.dtype) - lm_logits = F.linear(hidden_states, word_embeddings) + hidden_states = hidden_states.to(self.weight.dtype) + lm_logits = F.linear(hidden_states, self.weight) return lm_logits def chunked_forward(self, hidden_states): @@ -80,20 +67,17 @@ class LMHead(nn.Module): assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive" if not self._bf16_warning_shown: - if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: + if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: logger.warning( "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. " "Consider loading the model with torch_dtype='float32'" ) self._bf16_warning_shown = True - word_embeddings = self.word_embeddings.weight - num_embeddings = self.word_embeddings.num_embeddings - hidden_states = hidden_states.float() - output = torch.empty(*hidden_states.shape[:-1], num_embeddings) + output = torch.empty(*hidden_states.shape[:-1], self.out_features) - for i in range(0, num_embeddings, self.chunked_forward_step): - chunk = word_embeddings[i : i + self.chunked_forward_step].float() + for i in range(0, self.out_features, self.chunked_forward_step): + chunk = self.weight[i : i + self.chunked_forward_step].float() output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk) return output diff --git a/src/petals/client/ptune.py b/src/petals/client/ptune.py new file mode 100644 index 0000000..5cf613c --- /dev/null +++ b/src/petals/client/ptune.py @@ -0,0 +1,88 @@ +import dataclasses +from contextlib import contextmanager +from typing import Optional + +import torch +import torch.nn as nn +from hivemind import get_logger +from transformers import PretrainedConfig + +from petals.utils.misc import DUMMY + +logger = get_logger(__name__) + + +@dataclasses.dataclass +class PTuneConfig: + pre_seq_len: int = 0 # a number of tokens for prompt tuning. + tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"] + + +class PTuneMixin: + _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"] + + def init_prompts(self, config: PretrainedConfig) -> None: + if config.tuning_mode and "ptune" in config.tuning_mode: + assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0" + self.pre_seq_len = config.pre_seq_len + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + + with force_non_empty_weights(): + # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality + self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32) + if config.tuning_mode == "deep_ptune": + self.intermediate_prompt_embeddings = nn.Embedding( + self.pre_seq_len, + config.num_hidden_layers * config.hidden_size, + # ^-- TODO: should be num_hidden_layers - 1 + dtype=torch.float32, + ) + elif config.tuning_mode: + raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now") + + def set_requires_grad(self, value): + for p in self.parameters(): + p.requires_grad = value + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1) + prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device) + prompts = self.prompt_embeddings(prefix_tokens) + + if self.config.tuning_mode == "deep_ptune": + intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens) + intermediate_prompts = intermediate_prompts.view( + batch_size, + self.pre_seq_len, + self.config.num_hidden_layers, + self.config.hidden_size + # TODO: should be num_hidden_layers - 1 + ) + intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3]) + else: + intermediate_prompts = DUMMY + + dtype = self.word_embeddings.weight.dtype + return prompts.to(dtype), intermediate_prompts.to(dtype) + + +_original_register_parameter = nn.Module.register_parameter + + +@contextmanager +def force_non_empty_weights(): + """ + This context manager allows to bypass the accelerate.init_empty_weights() context manager + (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True. + The transformers library should replace all meta tensors by empty tensors by itself + but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`). + + [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515 + """ + + try: + possibly_patched_register_parameter = nn.Module.register_parameter + nn.Module.register_parameter = _original_register_parameter + yield + finally: + nn.Module.register_parameter = possibly_patched_register_parameter diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py deleted file mode 100644 index b556714..0000000 --- a/src/petals/client/remote_model.py +++ /dev/null @@ -1,268 +0,0 @@ -from contextlib import contextmanager -from typing import List, Optional, Union - -import hivemind -import torch -import torch.nn as nn -from hivemind.utils.logging import get_logger -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.bloom import ( - BloomConfig, - BloomForCausalLM, - BloomForSequenceClassification, - BloomModel, - BloomPreTrainedModel, -) - -from petals.bloom.modeling_utils import LMHead -from petals.client.remote_generation import RemoteGenerationMixin -from petals.client.remote_sequential import RemoteSequential -from petals.client.routing.sequence_manager import SequenceManagerConfig -from petals.constants import PUBLIC_INITIAL_PEERS -from petals.utils.misc import DUMMY - -logger = get_logger(__name__) - - -class DistributedBloomConfig(BloomConfig, SequenceManagerConfig): - """ - A bloom config that contains information about DHT peers. - To create a distributed model, one must provide dht_prefix and either initial_peers or dht. - """ - - initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT - dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name) - daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers - - pre_seq_len: int = 0 # a number of tokens for prompt tuning. - tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"] - - # This settings matter for running the client with dtype bfloat16 on CPU. - # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations. - use_chunked_forward: Union[str, bool] = "auto" - chunked_forward_step: int = 16384 - - -original_register_parameter = nn.Module.register_parameter - - -@contextmanager -def force_non_empty_weights(): - """ - This context manager allows to bypass the accelerate.init_empty_weights() context manager - (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True. - The transformers library should replace all meta tensors by empty tensors by itself - but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`). - - [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515 - """ - - try: - possibly_patched_register_parameter = nn.Module.register_parameter - nn.Module.register_parameter = original_register_parameter - yield - finally: - nn.Module.register_parameter = possibly_patched_register_parameter - - -class _FromPretrainedDefaultsMixin: - @classmethod - def from_pretrained( - cls, - *args, - low_cpu_mem_usage: Optional[bool] = None, - torch_dtype: Optional[Union[str, torch.dtype]] = None, - **kwargs, - ): - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - if torch_dtype is None: - # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast, - # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights. - torch_dtype = "auto" - return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs) - - from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( - "low_cpu_mem_usage(`bool`, *optional*)", - "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", - ).replace( - "torch_dtype (`str` or `torch.dtype`, *optional*)", - 'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)', - ) - - -class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel): - """BloomModel, but all transformer layers are hosted by the swarm""" - - _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [ - r"^(intermediate_)?prompt_embeddings\.weight$", - ] - - config_class = DistributedBloomConfig - - def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None): - assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." - assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" - - n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization - super().__init__(config) - assert len(self.h) == 0 - config.n_layer = n_layer - - self.h = RemoteSequential(config, dht=dht) - - # Forbid accumulate grads for embeddings and layernorm - self.set_requires_grad(False) - - if config.tuning_mode and "ptune" in config.tuning_mode: - assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0" - self.pre_seq_len = config.pre_seq_len - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - - with force_non_empty_weights(): - if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16): - logger.info( - "Prompt embeddings and their optimizer statistics will be kept in float32 " - "to increase ptune quality" - ) - self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32) - if config.tuning_mode == "deep_ptune": - self.intermediate_prompt_embeddings = nn.Embedding( - self.pre_seq_len, - config.num_hidden_layers * config.hidden_size, - # ^-- TODO: should be num_hidden_layers - 1 - dtype=torch.float32, - ) - elif config.tuning_mode: - raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now") - - def set_requires_grad(self, value): - for p in self.parameters(): - p.requires_grad = value - - def get_prompt(self, batch_size): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1) - prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device) - prompts = self.prompt_embeddings(prefix_tokens) - - if self.config.tuning_mode == "deep_ptune": - intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens) - intermediate_prompts = intermediate_prompts.view( - batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size # TODO: should be len(self.h) - 1 - ) - intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3]) - else: - intermediate_prompts = DUMMY - - dtype = self.word_embeddings.weight.dtype - return prompts.to(dtype), intermediate_prompts.to(dtype) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ): - assert attention_mask is None, "DistributedBloomModel does not support attention masks right now" - - for k, v in kwargs.items(): - if not (v is None or v is False): - logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})") - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: - batch_size = inputs_embeds.shape[0] - prompts, intermediate_prompts = self.get_prompt(batch_size) - inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - output_shape = input_shape + (hidden_states.size(-1),) - - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: - hidden_states = self.h(hidden_states, prompts=intermediate_prompts) - else: - hidden_states = self.h(hidden_states) - - # Remove prefix - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: - hidden_states = hidden_states[:, self.pre_seq_len :] - - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM): - """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm""" - - _keys_to_ignore_on_load_missing = ( - BloomForCausalLM._keys_to_ignore_on_load_missing - + DistributedBloomModel._keys_to_ignore_on_load_missing - + [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings - ) - - config_class = DistributedBloomConfig - - def __init__(self, config: DistributedBloomConfig): - BloomPreTrainedModel.__init__(self, config) - self.transformer = DistributedBloomModel(config) - self.lm_head = LMHead(config, self.transformer.word_embeddings) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.transformer.word_embeddings - - def get_output_embeddings(self): - if self.config.tie_word_embeddings: - return None - return self.lm_head - - def set_input_embeddings(self, new_embeddings: nn.Embedding): - assert isinstance(new_embeddings, nn.Embedding) - self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings - assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings - - def set_output_embeddings(self, new_lm_head: nn.Linear): - with torch.no_grad(): - self.lm_head.word_embeddings.weight[...] = new_lm_head.weight - self.lm_head.bias[...] = new_lm_head.bias - - -class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification): - _keys_to_ignore_on_load_missing = ( - BloomForSequenceClassification._keys_to_ignore_on_load_missing - + DistributedBloomModel._keys_to_ignore_on_load_missing - ) - - config_class = DistributedBloomConfig - - def __init__(self, config: DistributedBloomConfig): - BloomPreTrainedModel.__init__(self, config) - self.num_labels = config.num_labels - - self.transformer = DistributedBloomModel(config) - self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype) - - # Initialize weights and apply final processing - self.post_init() diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 39811e3..745b5c1 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -6,9 +6,8 @@ import torch from hivemind import DHT, get_logger from torch import nn -import petals.client from petals.client.inference_session import InferenceSession -from petals.client.routing.sequence_manager import RemoteSequenceManager +from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction from petals.data_structures import UID_DELIMITER from petals.utils.misc import DUMMY @@ -23,7 +22,7 @@ class RemoteSequential(nn.Module): def __init__( self, - config: petals.client.DistributedBloomConfig, + config: SequenceManagerConfig, *, sequence_manager: Optional[RemoteSequenceManager] = None, dht: Optional[DHT] = None, @@ -40,7 +39,7 @@ class RemoteSequential(nn.Module): if start_block is None: start_block = 0 if end_block is None: - end_block = self.config.n_layer + end_block = self.config.num_hidden_layers block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block)) sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht) self.sequence_manager = sequence_manager diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 6ac7bb0..1a31d66 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -20,6 +20,7 @@ from hivemind.utils.logging import get_logger import petals.dht_utils from petals.client.routing.sequence_info import RemoteSequenceInfo from petals.client.routing.spending_policy import NoSpendingPolicy +from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState from petals.server.handler import TransformerConnectionHandler @@ -28,6 +29,10 @@ logger = get_logger(__name__) @dataclasses.dataclass class SequenceManagerConfig: + initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT + dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name) + daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers + allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests @@ -73,6 +78,8 @@ class RemoteSequenceManager: dht: Optional[DHT] = None, state: Optional[SequenceManagerState] = None, ): + assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" + assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." assert len(block_uids) > 0, "Sequences must contain at least one block" self.config = config @@ -84,7 +91,7 @@ class RemoteSequenceManager: dht = DHT( initial_peers=config.initial_peers, client_mode=True, - num_workers=config.n_layer, + num_workers=config.num_hidden_layers, startup_timeout=config.daemon_startup_timeout, start=True, ) diff --git a/src/petals/models/__init__.py b/src/petals/models/__init__.py new file mode 100644 index 0000000..acb4d38 --- /dev/null +++ b/src/petals/models/__init__.py @@ -0,0 +1,2 @@ +from petals.models.bloom import * +from petals.models.llama import * diff --git a/src/petals/models/bloom/__init__.py b/src/petals/models/bloom/__init__.py new file mode 100644 index 0000000..911974b --- /dev/null +++ b/src/petals/models/bloom/__init__.py @@ -0,0 +1,7 @@ +from petals.models.bloom.block import WrappedBloomBlock +from petals.models.bloom.config import DistributedBloomConfig +from petals.models.bloom.model import ( + DistributedBloomForCausalLM, + DistributedBloomForSequenceClassification, + DistributedBloomModel, +) diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py new file mode 100644 index 0000000..f246bd8 --- /dev/null +++ b/src/petals/models/bloom/block.py @@ -0,0 +1,32 @@ +""" +Bloom intermediate layer +Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b +See commit history for authorship. +""" +from typing import Optional, Tuple + +import torch +from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor + + +class WrappedBloomBlock(BloomBlock): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs + ): + assert attention_mask is None, "Non-causal attention masks are not supported yet" + batch_size, seq_length = hidden_states.shape[:2] + past_length = 0 if layer_past is None else layer_past[0].shape[-1] + seq_length_with_past = seq_length + past_length + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length) + return super().forward( + hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs + ) diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py new file mode 100644 index 0000000..57c3e7b --- /dev/null +++ b/src/petals/models/bloom/config.py @@ -0,0 +1,35 @@ +import os +from typing import Optional, Union + +from hivemind import get_logger +from transformers.models.bloom import BloomConfig +from transformers.models.bloom.modeling_bloom import BloomAttention + +from petals.client.lm_head import LMHeadConfig +from petals.client.ptune import PTuneConfig +from petals.client.routing.sequence_manager import SequenceManagerConfig +from petals.models.bloom.block import WrappedBloomBlock +from petals.utils.auto_config import AutoDistributedConfig +from petals.utils.version import get_compatible_model_repo + +logger = get_logger(__name__) + + +class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig): + block_class = WrappedBloomBlock + attn_class = BloomAttention + block_prefix = "h" + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs + ): + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) + if loading_from_repo and dht_prefix is None: + # We need "-petals" for backward compatibility with Petals < 1.2.0 + dht_prefix = str(model_name_or_path) + "-petals" + logger.info(f"Using DHT prefix: {dht_prefix}") + return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + + +AutoDistributedConfig.register(DistributedBloomConfig) diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py new file mode 100644 index 0000000..fae9faf --- /dev/null +++ b/src/petals/models/bloom/model.py @@ -0,0 +1,134 @@ +from typing import Optional + +import hivemind +import torch +import torch.nn as nn +from hivemind.utils.logging import get_logger +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel + +from petals.client.from_pretrained import FromPretrainedMixin +from petals.client.lm_head import LMHead +from petals.client.ptune import PTuneMixin +from petals.client.remote_generation import RemoteGenerationMixin +from petals.client.remote_sequential import RemoteSequential +from petals.models.bloom.config import DistributedBloomConfig + +logger = get_logger(__name__) + + +class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): + """BloomModel, but all transformer layers are hosted by the swarm""" + + _keys_to_ignore_on_load_missing = ( + BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing + ) + _keys_to_ignore_on_load_unexpected = [r"^h\."] + + config_class = DistributedBloomConfig + + def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None): + n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization + super().__init__(config) + assert len(self.h) == 0 + config.num_hidden_layers = n_layer + + self.h = RemoteSequential(config, dht=dht) + + self.set_requires_grad(False) # Forbid accumulate grads for embeddings and layernorm + self.init_prompts(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now" + + for k, v in kwargs.items(): + if not (v is None or v is False): + logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})") + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + batch_size = inputs_embeds.shape[0] + prompts, intermediate_prompts = self.get_prompt(batch_size) + inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + output_shape = input_shape + (hidden_states.size(-1),) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = self.h(hidden_states, prompts=intermediate_prompts) + else: + hidden_states = self.h(hidden_states) + + # Remove prefix + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = hidden_states[:, self.pre_seq_len :] + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM): + _keys_to_ignore_on_load_missing = ( + BloomForCausalLM._keys_to_ignore_on_load_missing + + DistributedBloomModel._keys_to_ignore_on_load_missing + + [r"^lm_head\."] # Missing since they are shared with input embeddings + ) + _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedBloomConfig + + def __init__(self, config: DistributedBloomConfig): + BloomPreTrainedModel.__init__(self, config) + self.transformer = DistributedBloomModel(config) + self.lm_head = LMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + +class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification): + _keys_to_ignore_on_load_missing = ( + BloomForSequenceClassification._keys_to_ignore_on_load_missing + + DistributedBloomModel._keys_to_ignore_on_load_missing + ) + _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedBloomConfig + + def __init__(self, config: DistributedBloomConfig): + BloomPreTrainedModel.__init__(self, config) + self.num_labels = config.num_labels + + self.transformer = DistributedBloomModel(config) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype) + + # Initialize weights and apply final processing + self.post_init() diff --git a/src/petals/models/llama/__init__.py b/src/petals/models/llama/__init__.py new file mode 100644 index 0000000..8156939 --- /dev/null +++ b/src/petals/models/llama/__init__.py @@ -0,0 +1,7 @@ +from petals.models.llama.block import WrappedLlamaBlock +from petals.models.llama.config import DistributedLlamaConfig +from petals.models.llama.model import ( + DistributedLlamaForCausalLM, + DistributedLlamaForSequenceClassification, + DistributedLlamaModel, +) diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py new file mode 100644 index 0000000..2f07188 --- /dev/null +++ b/src/petals/models/llama/block.py @@ -0,0 +1,87 @@ +""" +LLaMA intermediate layer +Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +See commit history for authorship. +""" +from typing import Optional, Tuple + +import torch +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel + + +class WrappedLlamaBlock(LlamaDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + batch_size, seq_length, _ = hidden_states.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + past_key_value = layer_past + if past_key_value is not None: + past_key_values_length = past_key_value[0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) + + if position_ids is None: + device = hidden_states.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = LlamaModel._prepare_decoder_attention_mask( + None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_llama_to_bloom( + present_key_value, batch_size, seq_length_with_past + ) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_llama( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + key_states = key_states.permute(0, 2, 1) + key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim) + value_states = value_states.view(*key_states.shape) + return (key_states, value_states) + + def _reorder_cache_from_llama_to_bloom( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim) + key_states = key_states.view(*value_states.shape) + key_states = key_states.permute(0, 2, 1) + return (key_states, value_states) diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py new file mode 100644 index 0000000..a7e6681 --- /dev/null +++ b/src/petals/models/llama/config.py @@ -0,0 +1,35 @@ +import os +from typing import Optional, Union + +from hivemind import get_logger +from transformers.models.llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention + +from petals.client.lm_head import LMHeadConfig +from petals.client.ptune import PTuneConfig +from petals.client.routing.sequence_manager import SequenceManagerConfig +from petals.models.llama.block import WrappedLlamaBlock +from petals.utils.auto_config import AutoDistributedConfig + +logger = get_logger(__name__) + + +class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig): + block_class = WrappedLlamaBlock + attn_class = LlamaAttention + block_prefix = "model.layers" + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs + ): + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) + if loading_from_repo and dht_prefix is None: + dht_prefix = str(model_name_or_path) + if "/" in dht_prefix: # If present, strip repository name to merge blocks hosted by different accounts + dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :] + logger.info(f"Using DHT prefix: {dht_prefix}") + return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + + +AutoDistributedConfig.register(DistributedLlamaConfig) diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py new file mode 100644 index 0000000..37b4683 --- /dev/null +++ b/src/petals/models/llama/model.py @@ -0,0 +1,152 @@ +from typing import Optional + +import hivemind +import torch +import torch.nn as nn +from hivemind.utils.logging import get_logger +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel + +from petals.client.from_pretrained import FromPretrainedMixin +from petals.client.lm_head import LMHead +from petals.client.ptune import PTuneMixin +from petals.client.remote_generation import RemoteGenerationMixin +from petals.client.remote_sequential import RemoteSequential +from petals.models.llama.config import DistributedLlamaConfig + +logger = get_logger(__name__) + + +class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): + """LlamaModel, but all transformer layers are hosted by the swarm""" + + _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."] + + config_class = DistributedLlamaConfig + + def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None): + n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization + super().__init__(config) + assert len(self.layers) == 0 + config.num_hidden_layers = n_layer + + self.layers = RemoteSequential(config, dht=dht) + + self.set_requires_grad(False) # Forbid accumulate grads for embeddings and layernorm + self.init_prompts(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> BaseModelOutputWithPast: + assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now" + + for k, v in kwargs.items(): + if not (v is None or v is False): + logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})") + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + batch_size = inputs_embeds.shape[0] + prompts, intermediate_prompts = self.get_prompt(batch_size) + inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + + hidden_states = inputs_embeds + output_shape = input_shape + (hidden_states.size(-1),) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = self.layers(hidden_states, prompts=intermediate_prompts) + else: + hidden_states = self.layers(hidden_states) + + # Remove prefix + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = hidden_states[:, self.pre_seq_len :] + + # Add last hidden state + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.view(output_shape) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + @property + def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin + return self.embed_tokens + + @property + def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin + return nn.Identity() + + @property + def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin + return self.layers + + @property + def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin + return self.norm + + +class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM): + _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedLlamaConfig + + def __init__(self, config: DistributedLlamaConfig): + LlamaPreTrainedModel.__init__(self, config) + self.model = DistributedLlamaModel(config) + self.lm_head = LMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + @property + def transformer(self) -> DistributedLlamaModel: # For compatibility with RemoteGenerationMixin + return self.model + + +class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification): + _keys_to_ignore_on_load_missing = ( + LlamaForSequenceClassification._keys_to_ignore_on_load_missing + + DistributedLlamaModel._keys_to_ignore_on_load_missing + ) + _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedLlamaConfig + + def __init__(self, config): + LlamaPreTrainedModel.__init__(self, config) + self.num_labels = config.num_labels + + self.model = DistributedLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @property + def transformer(self) -> DistributedLlamaModel: # For compatibility with RemoteGenerationMixin + return self.model diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 76dc52b..adcd617 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -1,4 +1,3 @@ -"""Code for serving bloom blocks via hivemind-server""" from __future__ import annotations from collections import Counter @@ -12,8 +11,7 @@ from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger from tensor_parallel import TensorParallel from tensor_parallel.tensor_parallel import PerDeviceTensors -from transformers import BloomConfig -from transformers.models.bloom.modeling_bloom import BloomAttention +from transformers import PretrainedConfig from petals.data_structures import InferenceMetadata from petals.server.memory_cache import MemoryCache @@ -24,17 +22,19 @@ logger = get_logger(__name__) class TransformerBackend(ModuleBackend): - """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference""" + """A wrapper for a transformer block that can process requests for forward, backward and inference""" - def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs): + def __init__( + self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs + ): super().__init__(*args, **kwargs) assert isinstance(self.module, TensorParallel) self.config = config self.memory_cache = memory_cache for name, param in self.module.named_parameters(): - assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does" + assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does" for name, buf in self.module.named_buffers(): - assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does" + assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does" max_batch_size = self.forward_pool.max_batch_size device = self.module.devices[self.module.output_device_index] @@ -52,9 +52,10 @@ class TransformerBackend(ModuleBackend): self.shard_num_heads = [] for shard in self.module.module_shards: for submodule in shard.modules(): - if isinstance(submodule, BloomAttention): + if isinstance(submodule, config.attn_class): self.shard_num_heads.append(submodule.num_heads) - assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head + assert len(self.shard_num_heads) == len(self.module.devices) + assert sum(self.shard_num_heads) == config.num_attention_heads self.inference_schema = ( ( @@ -71,7 +72,7 @@ class TransformerBackend(ModuleBackend): def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]: """Create tensor descriptors for attention cache tensors used during inference_step""" - head_dim = self.config.hidden_size // self.config.n_head + head_dim = self.config.hidden_size // self.config.num_attention_heads cache_tensors = [] for device, num_heads in zip(self.module.devices, self.shard_num_heads): keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device) diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index 8d59d18..a6af3b0 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -2,12 +2,10 @@ from typing import Optional, Union import torch from accelerate import init_empty_weights -from transformers import BloomConfig +from transformers import PretrainedConfig -from petals.bloom.block import WrappedBloomBlock - -def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> torch.dtype: +def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype: """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise.""" if dtype not in ("auto", None): return dtype @@ -17,7 +15,7 @@ def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> def get_block_size( - config: BloomConfig, + config: PretrainedConfig, location: str, *, dtype: Optional[Union[str, torch.dtype]] = None, @@ -30,7 +28,7 @@ def get_block_size( ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations' with init_empty_weights(include_buffers=True): - block = WrappedBloomBlock(config) + block = config.block_class(config) n_params = sum(param.numel() for param in block.parameters()) if location == "memory" and load_in_8bit: diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py new file mode 100644 index 0000000..aab8a9e --- /dev/null +++ b/src/petals/server/from_pretrained.py @@ -0,0 +1,175 @@ +""" +Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code. +If necessary, one can rewrite this to implement a different behavior, such as: + - loading files from a local data source (e.g. S3) + - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to ) + - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html ) + +""" +import json +import time +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from hivemind.utils.logging import get_logger +from huggingface_hub import get_hf_file_metadata, hf_hub_url +from transformers import PretrainedConfig +from transformers.utils import get_file_from_repo + +from petals.server.block_utils import resolve_block_dtype +from petals.utils.auto_config import AutoDistributedConfig +from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for + +logger = get_logger(__name__) + + +def load_pretrained_block( + model_name: str, + block_index: int, + *, + config: Optional[PretrainedConfig] = None, + torch_dtype: Union[torch.dtype, str] = "auto", + revision: Optional[str] = None, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, + max_disk_space: Optional[int] = None, +) -> nn.Module: + if config is None: + config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token) + if cache_dir is None: + cache_dir = DEFAULT_CACHE_DIR + + assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" + torch_dtype = resolve_block_dtype(config, torch_dtype) + + with init_empty_weights(): + block = config.block_class(config) + + block_prefix = f"{config.block_prefix}.{block_index}." + state_dict = _load_state_dict_from_repo( + model_name, + block_prefix, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + max_disk_space=max_disk_space, + ) + + # dummy load, check that keys match + report = block.load_state_dict(state_dict, strict=True) + assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}" + + for param_name, _ in block.named_parameters(): + assert param_name in state_dict, f"{param_name} not in state dict" + param = state_dict[param_name] + if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + param = param.to(torch_dtype) + set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) + + logger.info(f"Loaded {model_name} block {block_index}, {report}") + return block + + +StateDict = Dict[str, torch.Tensor] + + +def _load_state_dict_from_repo( + model_name: str, + block_prefix: str, + *, + revision: Optional[str] = None, + use_auth_token: Optional[str] = None, + cache_dir: str, + max_disk_space: Optional[int] = None, +) -> StateDict: + index_file = get_file_from_repo( + model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir + ) + if index_file is not None: # Sharded model + with open(index_file) as f: + index = json.load(f) + filenames = { + filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix) + } + if not filenames: + raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}") + else: # Non-sharded model + filenames = {"pytorch_model.bin"} + logger.debug(f"Loading {block_prefix}* from {filenames}") + + state_dict = {} + for filename in filenames: + shard_state_dict = _load_state_dict_from_file( + model_name, + filename, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + max_disk_space=max_disk_space, + ) + shard_state_dict = { + param_name[len(block_prefix) :]: param + for param_name, param in shard_state_dict.items() + if param_name.startswith(block_prefix) + } # Remove unused parameters from memory + state_dict.update(shard_state_dict) + return state_dict + + +def _load_state_dict_from_file( + model_name: str, + filename: str, + *, + revision: Optional[str] = None, + use_auth_token: Optional[str] = None, + cache_dir: str, + max_disk_space: Optional[int] = None, + delay: float = 30, +) -> StateDict: + # First, try to find the weights locally + try: + with allow_cache_reads(cache_dir): + path = get_file_from_repo( + model_name, + filename, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + local_files_only=True, + ) + if path is not None: + return torch.load(path, map_location="cpu") + except Exception: + logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True) + + # If not found, ensure that we have enough disk space to download them (maybe remove something) + while True: + try: + with allow_cache_writes(cache_dir): + url = hf_hub_url(model_name, filename, revision=revision) + file_size = get_hf_file_metadata(url, token=use_auth_token).size + if file_size is not None: + free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) + else: + logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}") + + path = get_file_from_repo( + model_name, + filename, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + local_files_only=False, + ) + if path is None: + raise RuntimeError(f"File {filename} does not exist in repo {model_name}") + return torch.load(path, map_location="cpu") + except Exception as e: + logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True) + time.sleep(delay) + + +DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/server/server.py b/src/petals/server/server.py index e424fb5..75a999e 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -14,21 +14,23 @@ from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.runtime import Runtime from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger -from transformers import BloomConfig +from transformers import PretrainedConfig -from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size, resolve_block_dtype +from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability from petals.server.throughput import get_dtype_name, get_server_throughput +from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR +from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) @@ -53,7 +55,7 @@ class Server: max_batch_size: int = 2048, inference_max_length: int = 2048, torch_dtype: str = "auto", - revision: str = "main", + revision: Optional[str] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, attn_cache_tokens: int = 8192, @@ -83,25 +85,32 @@ class Server: ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" + converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path) self.converted_model_name_or_path = converted_model_name_or_path + self.num_handlers = num_handlers self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size self.inference_max_length = inference_max_length self.compression = compression self.stats_report_interval, self.update_period = stats_report_interval, update_period self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads - self.use_auth_token = use_auth_token + self.revision, self.use_auth_token = revision, use_auth_token if custom_module_path is not None: add_custom_models_from_file(custom_module_path) + self.block_config = AutoDistributedConfig.from_pretrained( + converted_model_name_or_path, + use_auth_token=use_auth_token, + revision=revision, + ) + if prefix is None: - prefix = converted_model_name_or_path - assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, ( - f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " - f"Please specify --prefix manually when starting a server" - ) - logger.debug(f"Automatic dht prefix: {prefix}") + prefix = self.block_config.dht_prefix + assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, ( + f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. " + f"Please specify another --prefix manually when starting a server" + ) self.prefix = prefix if expiration is None: @@ -111,12 +120,9 @@ class Server: self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout - self.block_config = BloomConfig.from_pretrained( - converted_model_name_or_path, - use_auth_token=use_auth_token, - revision=revision, - ) - self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] + self.module_uids = [ + f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers) + ] if dht_client_mode is None: is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs) @@ -125,7 +131,7 @@ class Server: self.dht = DHT( initial_peers=initial_peers, start=True, - num_workers=self.block_config.n_layer, + num_workers=self.block_config.num_hidden_layers, use_relay=use_relay, use_auto_relay=use_auto_relay, client_mode=dht_client_mode, @@ -161,10 +167,10 @@ class Server: if load_in_8bit is None: load_in_8bit = device.type == "cuda" self.load_in_8bit = load_in_8bit - logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format") + logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format") - max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens - self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8 + cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens + self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8 assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" if num_blocks is None and block_indices is None: @@ -192,6 +198,7 @@ class Server: assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: throughput = get_server_throughput( + converted_model_name_or_path, self.block_config, device, torch_dtype, @@ -239,11 +246,12 @@ class Server: num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block)) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" + num_blocks = min(num_blocks, self.block_config.num_hidden_layers) logger.info( f"Server will fill all your GPU memory with {num_blocks} transformer blocks. " f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually" ) - return min(num_blocks, self.block_config.n_layer) + return num_blocks def run(self): while True: @@ -274,6 +282,7 @@ class Server: step_timeout=self.step_timeout, prefetch_batches=self.prefetch_batches, sender_threads=self.sender_threads, + revision=self.revision, use_auth_token=self.use_auth_token, load_in_8bit=self.load_in_8bit, tensor_parallel_devices=self.tensor_parallel_devices, @@ -352,7 +361,7 @@ class ModuleContainer(threading.Thread): dht: DHT, prefix: str, converted_model_name_or_path: str, - block_config: BloomConfig, + block_config: PretrainedConfig, attn_cache_bytes: int, alloc_timeout: float, throughput: float, @@ -366,6 +375,7 @@ class ModuleContainer(threading.Thread): compression: CompressionType, update_period: float, expiration: Optional[float], + revision: Optional[str], use_auth_token: Optional[str], load_in_8bit: bool, tensor_parallel_devices: Sequence[torch.device], @@ -394,14 +404,14 @@ class ModuleContainer(threading.Thread): block = load_pretrained_block( converted_model_name_or_path, block_index, - block_config, + config=block_config, torch_dtype=torch_dtype, + revision=revision, use_auth_token=use_auth_token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True) - blocks[module_uid] = TransformerBackend( module_uid, block, @@ -564,13 +574,9 @@ class ModuleContainer(threading.Thread): self.ready.clear() + logger.debug("Shutting down connection handlers") for handler in self.conn_handlers: handler.shutdown() - logger.debug("Connection handlers terminated") - - if self.checkpoint_saver is not None: - self.checkpoint_saver.stop.set() - self.checkpoint_saver.join() logger.debug(f"Shutting down pools") for pool in self.runtime.pools: diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index dbefb35..2ee1ca1 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -5,15 +5,13 @@ import multiprocessing as mp import os import time from collections import Counter -from hashlib import sha256 from pathlib import Path from typing import Dict, Optional, Sequence, Union import torch from hivemind.utils.logging import get_logger -from transformers import BloomConfig +from transformers import PretrainedConfig -from petals.bloom.block import WrappedBloomBlock from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_block import convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR @@ -35,7 +33,8 @@ if not hasattr(speedtest, "Speedtest"): def get_server_throughput( - config: BloomConfig, + model_name: str, + config: PretrainedConfig, device: torch.device, dtype: Union[str, torch.dtype], *, @@ -59,7 +58,7 @@ def get_server_throughput( fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX) # The OS will release the lock when lock_fd is closed or the process is killed - cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}" + cache_key = f"model_{model_name}" cache_key += f"_device_{get_device_name(device).replace(' ', '_')}" cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}" if len(tensor_parallel_devices) > 1: @@ -101,7 +100,7 @@ def get_server_throughput( def measure_throughput_info( - config: BloomConfig, + config: PretrainedConfig, device: torch.device, dtype: torch.dtype, *, @@ -127,7 +126,7 @@ def measure_throughput_info( return throughput_info -def measure_network_rps(config: BloomConfig, *, timeout: float = 60) -> Optional[float]: +def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]: pipe_recv, pipe_send = mp.Pipe(duplex=False) process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,)) process.start() @@ -160,7 +159,7 @@ def _measure_bits_per_second(pipe_send: mp.Pipe): def measure_compute_rps( - config: BloomConfig, + config: PretrainedConfig, device: torch.device, dtype: torch.dtype, *, @@ -172,7 +171,7 @@ def measure_compute_rps( if not tensor_parallel_devices: tensor_parallel_devices = (device,) with torch.inference_mode(): - block = WrappedBloomBlock(config).to(dtype) + block = config.block_class(config).to(dtype) block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True) cache = None @@ -203,4 +202,7 @@ def get_device_name(device: torch.device) -> str: def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str: - return "8-bit" if load_in_8bit else str(dtype) + name = str(dtype) + if load_in_8bit: + name += ", 8-bit quantized" + return name diff --git a/src/petals/utils/__init__.py b/src/petals/utils/__init__.py index e69de29..654e98c 100644 --- a/src/petals/utils/__init__.py +++ b/src/petals/utils/__init__.py @@ -0,0 +1 @@ +from petals.utils.auto_config import AutoDistributedConfig diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py new file mode 100644 index 0000000..b6fca41 --- /dev/null +++ b/src/petals/utils/auto_config.py @@ -0,0 +1,23 @@ +from typing import Type + +from transformers import AutoConfig, PretrainedConfig + +CONFIG_MAPPING = {} # Populated with AutoDistributedConfig.register() + + +class AutoDistributedConfig: + @classmethod + def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig: + config = AutoConfig.from_pretrained(*args, **kwargs) + if config.model_type not in CONFIG_MAPPING: + raise ValueError(f"Petals does not support model type {config.model_type}") + + dist_config_class = CONFIG_MAPPING[config.model_type] + return dist_config_class.from_pretrained(*args, **kwargs) + + @staticmethod + def register(config_class: Type[PretrainedConfig]) -> None: + assert issubclass(config_class, PretrainedConfig) + assert config_class.model_type not in CONFIG_MAPPING + + CONFIG_MAPPING[config_class.model_type] = config_class diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index b58cd1a..28aea56 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -10,18 +10,15 @@ import torch import torch.nn as nn from hivemind.utils.logging import get_logger, use_hivemind_log_handler from tensor_parallel.slicing_configs import get_bloom_config -from transformers import BloomConfig -from transformers.models.bloom.modeling_bloom import BloomAttention - -from petals.bloom.block import WrappedBloomBlock +from transformers import PretrainedConfig use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) def convert_block( - block: WrappedBloomBlock, - config: BloomConfig, + block: nn.Module, + config: PretrainedConfig, tensor_parallel_devices: Sequence[torch.device], output_device: torch.device, load_in_8bit: bool, @@ -58,7 +55,7 @@ def convert_block( return block -def replace_8bit_linear(model: nn.Module, threshold=6.0): +def replace_8bit_linear(model: nn.Module, threshold=6.0) -> nn.Module: """ A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes` library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): @@ -100,17 +97,22 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0): def make_tensor_parallel( - block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device -): - tp_config = get_bloom_config(model_config, devices) - del tp_config.state_rules[re.compile(".*word_embeddings.weight$")] + block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device +) -> nn.Module: + if model_config.model_type == "bloom": + tp_config = get_bloom_config(model_config, devices) + del tp_config.state_rules[re.compile(".*word_embeddings.weight$")] + else: + if len(devices) > 1: + logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution") + tp_config = None tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True) total_heads = 0 for tp_shard in tp_block.module_shards: for submodule in tp_shard.modules(): - if isinstance(submodule, BloomAttention): + if isinstance(submodule, model_config.attn_class): total_heads += submodule.num_heads - assert total_heads == model_config.n_head + assert total_heads == model_config.num_attention_heads return tp_block diff --git a/src/petals/utils/disk_cache.py b/src/petals/utils/disk_cache.py index 3217e34..aefea1d 100644 --- a/src/petals/utils/disk_cache.py +++ b/src/petals/utils/disk_cache.py @@ -57,13 +57,16 @@ def free_disk_space_for( available_space = shutil.disk_usage(cache_dir).free - os_quota if max_disk_space is not None: available_space = min(available_space, max_disk_space - occupied_space) + + gib = 1024**3 + logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB") if size <= available_space: return revisions = [revision for repo in model_repos for revision in repo.revisions] revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified)) - # Remove as few least recently used blocks as possible + # Remove as few least recently used shards as possible pending_removal = [] freed_space = 0 extra_space_needed = size - available_space @@ -73,9 +76,8 @@ def free_disk_space_for( if freed_space >= extra_space_needed: break - gib = 1024**3 if pending_removal: - logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space") + logger.info(f"Removing {len(pending_removal)} shards to free {freed_space / gib:.1f} GiB of disk space") delete_strategy = cache_info.delete_revisions(*pending_removal) delete_strategy.execute() diff --git a/src/petals/utils/version.py b/src/petals/utils/version.py index f4a5be1..67b3866 100644 --- a/src/petals/utils/version.py +++ b/src/petals/utils/version.py @@ -1,3 +1,7 @@ +import os +import re +from typing import Union + import requests from hivemind.utils.logging import TextStyle, get_logger from packaging.version import parse @@ -7,7 +11,7 @@ import petals logger = get_logger(__name__) -def validate_version(): +def validate_version() -> None: logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}") try: r = requests.get("https://pypi.python.org/pypi/petals/json") @@ -24,3 +28,17 @@ def validate_version(): ) except Exception as e: logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True) + + +def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]: + if model_name_or_path is None: + return None + + match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path)) + if match is None: + return model_name_or_path + + logger.info( + f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones" + ) + return match.group(1) diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index 6909ccf..d42666b 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -1,7 +1,7 @@ import pytest import torch -from petals.client import DistributedBloomConfig +from petals import AutoDistributedConfig from petals.server.throughput import measure_compute_rps from test_utils import MODEL_NAME @@ -9,7 +9,7 @@ from test_utils import MODEL_NAME @pytest.mark.forked @pytest.mark.parametrize("tensor_parallel", [False, True]) def test_compute_throughput(tensor_parallel: bool): - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) + config = AutoDistributedConfig.from_pretrained(MODEL_NAME) tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else () compute_rps = measure_compute_rps( config, diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index a05387d..62c4e89 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -1,13 +1,10 @@ import random -from typing import Union import pytest import torch -from transformers.models.bloom.configuration_bloom import BloomConfig -from petals.bloom.block import WrappedBloomBlock -from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block -from petals.client import DistributedBloomConfig, RemoteSequential +from petals import DistributedBloomConfig, RemoteSequential +from petals.server.from_pretrained import load_pretrained_block from test_utils import * @@ -16,21 +13,22 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) remote_sequential = RemoteSequential(config) - for block_index in random.sample(range(config.n_layer), 3): + for block_index in random.sample(range(config.num_hidden_layers), 3): remote_block = remote_sequential[block_index] inputs = torch.randn(1, 8, config.hidden_size) outputs_forward = remote_block(inputs) outputs_inference = [] - with remote_block.inference_session(max_length=inputs.shape[1]) as sess: - for i in range(inputs.shape[1]): - outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) - - # test that max length is respected - with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: - sess.step(inputs[:, -1:, :]) - assert "Maximum length exceeded" in repr(exc_info.value) + with torch.inference_mode(): + with remote_block.inference_session(max_length=inputs.shape[1]) as sess: + for i in range(inputs.shape[1]): + outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) + + # test that max length is respected + with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: + sess.step(inputs[:, -1:, :]) + assert "Maximum length exceeded" in repr(exc_info.value) outputs_inference = torch.cat(outputs_inference, dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) @@ -38,47 +36,3 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward) assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference) - - -def _old_load_pretrained_block( - converted_model_name_or_path: str, - block_index: int, - torch_dtype: Union[torch.dtype, str] = "auto", -) -> WrappedBloomBlock: - """Load the BLOOM block by directly initializing the weights. - This test is used to check consistency with the previous implementation and can be removed in the future.""" - config = BloomConfig.from_pretrained(converted_model_name_or_path) - - block = WrappedBloomBlock(config) - state_dict = _load_state_dict( - converted_model_name_or_path, - block_index, - config, - cache_dir=None, - ) - - if torch_dtype == "auto": - with torch.no_grad(): - for name, param in block.named_parameters(): - assert name in state_dict, f"{name} not in state dict" - param.data = param.data.to(state_dict[name].dtype) - else: - assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - block = block.to(dtype=torch_dtype) - - block.load_state_dict(state_dict, strict=True) - return block - - -@pytest.mark.forked -def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8): - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) - torch.random.manual_seed(0) - inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype) - - block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype) - ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype) - - outputs = block.forward(inputs)[0] - outputs_ref = ref_block.forward(inputs)[0] - assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward) diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 15f3b5c..d20f654 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -7,9 +7,9 @@ import pytest import torch -from petals.bloom.from_pretrained import load_pretrained_block -from petals.client import DistributedBloomConfig +from petals import DistributedBloomConfig from petals.client.remote_sequential import RemoteSequential +from petals.server.from_pretrained import load_pretrained_block from test_utils import * diff --git a/tests/test_dtype.py b/tests/test_dtype.py index 03afd83..d102077 100644 --- a/tests/test_dtype.py +++ b/tests/test_dtype.py @@ -1,17 +1,16 @@ import pytest import torch -from petals.bloom.from_pretrained import load_pretrained_block -from petals.client import DistributedBloomConfig from petals.server.block_utils import resolve_block_dtype +from petals.server.from_pretrained import load_pretrained_block +from petals.utils.auto_config import AutoDistributedConfig from test_utils import MODEL_NAME @pytest.mark.forked @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"]) -def test_backend_dtype(torch_dtype): - config = DistributedBloomConfig.from_pretrained(MODEL_NAME) - block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype=torch_dtype) - backend_dtype = resolve_block_dtype(config, torch_dtype) - other_backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype - assert backend_dtype == other_backend_dtype +def test_block_dtype(torch_dtype): + config = AutoDistributedConfig.from_pretrained(MODEL_NAME) + block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype) + expected_dtype = resolve_block_dtype(config, torch_dtype) + assert all(param.dtype == expected_dtype for param in block.parameters()) diff --git a/tests/test_full_model.py b/tests/test_full_model.py index cef002e..f2679f2 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -5,7 +5,7 @@ from hivemind import get_logger from transformers.generation import BeamSearchScorer from transformers.models.bloom import BloomForCausalLM -from petals.client.remote_model import DistributedBloomForCausalLM +from petals import DistributedBloomForCausalLM from test_utils import * logger = get_logger(__name__) @@ -20,7 +20,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato ) config = model.config assert isinstance(model, DistributedBloomForCausalLM) - assert len(model.transformer.h) == model.config.n_layer + assert len(model.transformer.h) == model.config.num_hidden_layers test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index d46ca1c..734683f 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -4,10 +4,10 @@ import torch.nn.functional as F from hivemind import DHT, BatchTensorDescriptor, get_logger from hivemind.proto import runtime_pb2 -from petals.bloom.from_pretrained import load_pretrained_block +from petals import DistributedBloomConfig from petals.client import RemoteSequenceManager, RemoteSequential -from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER +from petals.server.from_pretrained import load_pretrained_block from test_utils import * logger = get_logger(__name__) @@ -28,10 +28,10 @@ def test_remote_sequential(): full_grad = test_inputs.grad.clone() test_inputs.grad.data.zero_() - first_half = sequential[: config.n_layer // 2] - second_half = sequential[config.n_layer // 2 :] + first_half = sequential[: config.num_hidden_layers // 2] + second_half = sequential[config.num_hidden_layers // 2 :] assert len(first_half) + len(second_half) == len(sequential) - assert abs(len(first_half) - len(second_half)) == config.n_layer % 2 + assert abs(len(first_half) - len(second_half)) == config.num_hidden_layers % 2 for m in sequential, first_half, second_half: assert isinstance(repr(m), str) @@ -46,7 +46,7 @@ def test_remote_sequential(): assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3) # test RemoteSequential with lossy compression - block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] + block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)] lossy_sequential = RemoteSequential( config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht) ) @@ -90,7 +90,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1) output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1) input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1) - intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True) + intermediate_prompts = torch.randn( + config.num_hidden_layers, batch_size, pre_seq_len, config.hidden_size, requires_grad=True + ) input_prompts = input_prompts.detach().requires_grad_(True) intermediate_prompts = intermediate_prompts.detach().requires_grad_(True) @@ -110,7 +112,7 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): assert intermediate_prompts_ref.grad is None outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1) - for block_index in range(config.n_layer): + for block_index in range(config.num_hidden_layers): block_prompt = intermediate_prompts_ref[block_index] outputs_ref[:, : block_prompt.shape[1]] += block_prompt diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 38e9a8a..86d04ca 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -5,8 +5,8 @@ import pytest import torch from hivemind import DHT, get_logger +from petals import DistributedBloomConfig from petals.client import RemoteSequenceManager, RemoteSequential -from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER from test_utils import * @@ -22,7 +22,7 @@ def test_sequence_manager_basics(mode: str): shutdown_evt = threading.Event() # test RemoteSequential with lossy compression - block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] + block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)] sequential = RemoteSequential( config, sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt), diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index 11d2565..5de3393 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -4,7 +4,7 @@ import hivemind import pytest import torch -from petals.client import DistributedBloomConfig, RemoteSequential +from petals import DistributedBloomConfig, RemoteSequential from petals.server.handler import CACHE_TOKENS_AVAILABLE from test_utils import * diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 84fcab4..408a261 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -6,7 +6,7 @@ import transformers from tensor_parallel import TensorParallel from tensor_parallel.slicing_configs import get_bloom_config -from petals.bloom.from_pretrained import load_pretrained_block +from petals.server.from_pretrained import load_pretrained_block from test_utils import MODEL_NAME From 7a37513f77091a941e379caccc1b77aad13d502a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 23 Jun 2023 18:42:50 +0400 Subject: [PATCH 091/168] Add AutoDistributed{Model, ModelForCausalLM, ModelForSequenceClassification} (#329) This PR adds `petals.AutoDistributed{Model, ModelForCausalLM, ModelForSequenceClassification}` classes, similar to their `transformers.Auto{Model, ModelForCausalLM, ModelForSequenceClassification}` counterparts. --- src/petals/models/bloom/__init__.py | 8 +++++ src/petals/models/bloom/config.py | 3 -- src/petals/models/llama/__init__.py | 8 +++++ src/petals/models/llama/config.py | 3 -- src/petals/utils/__init__.py | 7 +++- src/petals/utils/auto_config.py | 55 ++++++++++++++++++++++------- 6 files changed, 65 insertions(+), 19 deletions(-) diff --git a/src/petals/models/bloom/__init__.py b/src/petals/models/bloom/__init__.py index 911974b..2932701 100644 --- a/src/petals/models/bloom/__init__.py +++ b/src/petals/models/bloom/__init__.py @@ -5,3 +5,11 @@ from petals.models.bloom.model import ( DistributedBloomForSequenceClassification, DistributedBloomModel, ) +from petals.utils.auto_config import register_model_classes + +register_model_classes( + config=DistributedBloomConfig, + model=DistributedBloomModel, + model_for_causal_lm=DistributedBloomForCausalLM, + model_for_sequence_classification=DistributedBloomForSequenceClassification, +) diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py index 57c3e7b..a376aab 100644 --- a/src/petals/models/bloom/config.py +++ b/src/petals/models/bloom/config.py @@ -30,6 +30,3 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM dht_prefix = str(model_name_or_path) + "-petals" logger.info(f"Using DHT prefix: {dht_prefix}") return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) - - -AutoDistributedConfig.register(DistributedBloomConfig) diff --git a/src/petals/models/llama/__init__.py b/src/petals/models/llama/__init__.py index 8156939..e5d8aa4 100644 --- a/src/petals/models/llama/__init__.py +++ b/src/petals/models/llama/__init__.py @@ -5,3 +5,11 @@ from petals.models.llama.model import ( DistributedLlamaForSequenceClassification, DistributedLlamaModel, ) +from petals.utils.auto_config import register_model_classes + +register_model_classes( + config=DistributedLlamaConfig, + model=DistributedLlamaModel, + model_for_causal_lm=DistributedLlamaForCausalLM, + model_for_sequence_classification=DistributedLlamaForSequenceClassification, +) diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py index a7e6681..f5dc6f6 100644 --- a/src/petals/models/llama/config.py +++ b/src/petals/models/llama/config.py @@ -30,6 +30,3 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :] logger.info(f"Using DHT prefix: {dht_prefix}") return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) - - -AutoDistributedConfig.register(DistributedLlamaConfig) diff --git a/src/petals/utils/__init__.py b/src/petals/utils/__init__.py index 654e98c..0852074 100644 --- a/src/petals/utils/__init__.py +++ b/src/petals/utils/__init__.py @@ -1 +1,6 @@ -from petals.utils.auto_config import AutoDistributedConfig +from petals.utils.auto_config import ( + AutoDistributedConfig, + AutoDistributedModel, + AutoDistributedModelForCausalLM, + AutoDistributedModelForSequenceClassification, +) diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py index b6fca41..f587051 100644 --- a/src/petals/utils/auto_config.py +++ b/src/petals/utils/auto_config.py @@ -1,23 +1,54 @@ -from typing import Type +from dataclasses import dataclass +from typing import Optional, Type -from transformers import AutoConfig, PretrainedConfig +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel -CONFIG_MAPPING = {} # Populated with AutoDistributedConfig.register() +@dataclass +class _ModelClasses: + config: Type[PretrainedConfig] + model: Optional[Type[PreTrainedModel]] = None + model_for_causal_lm: Optional[Type[PreTrainedModel]] = None + model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None + + +_CLASS_MAPPING = {} # Populated by petals.models.* subpackages with register_model_classes() + + +def register_model_classes(*, config: Type[PretrainedConfig], **kwargs): + assert issubclass(config, PretrainedConfig) + assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered" + + _CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs) + + +class _AutoDistributedBase: + _mapping_field = None # Should be defined in child classes -class AutoDistributedConfig: @classmethod def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig: config = AutoConfig.from_pretrained(*args, **kwargs) - if config.model_type not in CONFIG_MAPPING: + if config.model_type not in _CLASS_MAPPING: raise ValueError(f"Petals does not support model type {config.model_type}") - dist_config_class = CONFIG_MAPPING[config.model_type] - return dist_config_class.from_pretrained(*args, **kwargs) + proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field) + if proper_cls is None: + raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}") + + return proper_cls.from_pretrained(*args, **kwargs) + + +class AutoDistributedConfig(_AutoDistributedBase): + _mapping_field = "config" + + +class AutoDistributedModel(_AutoDistributedBase): + _mapping_field = "model" + + +class AutoDistributedModelForCausalLM(_AutoDistributedBase): + _mapping_field = "model_for_causal_lm" - @staticmethod - def register(config_class: Type[PretrainedConfig]) -> None: - assert issubclass(config_class, PretrainedConfig) - assert config_class.model_type not in CONFIG_MAPPING - CONFIG_MAPPING[config_class.model_type] = config_class +class AutoDistributedModelForSequenceClassification(_AutoDistributedBase): + _mapping_field = "model_for_sequence_classification" From 47a2b1ee65a6fe2cabc3e262c116c04149fe231b Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 24 Jun 2023 02:30:13 +0400 Subject: [PATCH 092/168] Fix llama's lm_head.weight.requires_grad (#330) By default, `llama's lm_head.weight.requires_grad` was True, but we expect it to be False. --- src/petals/client/lm_head.py | 3 ++- src/petals/client/ptune.py | 4 ---- src/petals/models/bloom/model.py | 2 +- src/petals/models/llama/model.py | 2 +- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/petals/client/lm_head.py b/src/petals/client/lm_head.py index ddd2887..938d6da 100644 --- a/src/petals/client/lm_head.py +++ b/src/petals/client/lm_head.py @@ -26,7 +26,8 @@ class LMHead(nn.Module): super().__init__() if not config.tie_word_embeddings: - self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False)) + self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size)) + self.weight.requires_grad = False else: self.weight = None # Will be set to get_input_embeddings().weight during loading the model self.bias = None diff --git a/src/petals/client/ptune.py b/src/petals/client/ptune.py index 5cf613c..684cc23 100644 --- a/src/petals/client/ptune.py +++ b/src/petals/client/ptune.py @@ -40,10 +40,6 @@ class PTuneMixin: elif config.tuning_mode: raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now") - def set_requires_grad(self, value): - for p in self.parameters(): - p.requires_grad = value - def get_prompt(self, batch_size): prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1) prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device) diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index fae9faf..e4961d3 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -35,7 +35,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): self.h = RemoteSequential(config, dht=dht) - self.set_requires_grad(False) # Forbid accumulate grads for embeddings and layernorm + self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm self.init_prompts(config) def forward( diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py index 37b4683..244207b 100644 --- a/src/petals/models/llama/model.py +++ b/src/petals/models/llama/model.py @@ -33,7 +33,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): self.layers = RemoteSequential(config, dht=dht) - self.set_requires_grad(False) # Forbid accumulate grads for embeddings and layernorm + self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm self.init_prompts(config) def forward( From fecee8c4dcfde4afa576dde009e6d3f623b4c034 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 24 Jun 2023 20:19:18 +0400 Subject: [PATCH 093/168] Show license links when loading models (#332) --- src/petals/models/bloom/config.py | 2 ++ src/petals/models/llama/config.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py index a376aab..d6a8146 100644 --- a/src/petals/models/bloom/config.py +++ b/src/petals/models/bloom/config.py @@ -24,6 +24,8 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): + logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license") + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: # We need "-petals" for backward compatibility with Petals < 1.2.0 diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py index f5dc6f6..dd5f6b1 100644 --- a/src/petals/models/llama/config.py +++ b/src/petals/models/llama/config.py @@ -23,6 +23,11 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): + logger.info( + "LLaMA is available solely for non-commercial research purposes. " + "Make sure you follow the terms of use: https://bit.ly/llama-license" + ) + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: dht_prefix = str(model_name_or_path) From d126ee3053715ae09d648787fb7635d4c6675f7f Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 30 Jun 2023 01:12:59 +0400 Subject: [PATCH 094/168] Add benchmark scripts (#319) This PR: - Adds benchmark scripts for inference, forward pass, and full training step (e.g. used for experiments in our paper). - Fixes bug with dtypes in `petals.DistributedBloomForSequenceClassification`. - (minor refactor) Moves `DTYPE_MAP` to `petals.constants` as a useful constant. --- benchmarks/benchmark_forward.py | 69 ++++++++++++++++++ benchmarks/benchmark_inference.py | 64 +++++++++++++++++ benchmarks/benchmark_training.py | 101 +++++++++++++++++++++++++++ src/petals/cli/run_server.py | 4 +- src/petals/constants.py | 4 ++ src/petals/models/bloom/model.py | 2 +- src/petals/server/from_pretrained.py | 4 +- src/petals/server/server.py | 4 +- 8 files changed, 244 insertions(+), 8 deletions(-) create mode 100755 benchmarks/benchmark_forward.py create mode 100755 benchmarks/benchmark_inference.py create mode 100755 benchmarks/benchmark_training.py diff --git a/benchmarks/benchmark_forward.py b/benchmarks/benchmark_forward.py new file mode 100755 index 0000000..0a7d4f8 --- /dev/null +++ b/benchmarks/benchmark_forward.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +import argparse +import multiprocessing as mp +from time import perf_counter + +import torch +from hivemind.utils.logging import get_logger + +from petals import AutoDistributedModel +from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS + +logger = get_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="bigscience/bloom") + parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS) + parser.add_argument("--torch_dtype", type=str, default="bfloat16") + parser.add_argument("--n_processes", type=str, default=1) + parser.add_argument("--seq_len", type=int, default=128) + parser.add_argument("--n_steps", type=int, default=100) + parser.add_argument("--batch_size", type=int, required=True) + parser.add_argument("--warmup_steps", type=int, default=1) + args = parser.parse_args() + + if args.n_processes == "n_gpus": + args.n_processes = torch.cuda.device_count() + else: + args.n_processes = int(args.n_processes) + + processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)] + for proc in processes: + proc.start() + for proc in processes: + proc.join() + + +@torch.inference_mode() +def benchmark_forward(process_idx, args): + model = AutoDistributedModel.from_pretrained( + args.model, + initial_peers=args.initial_peers, + torch_dtype=DTYPE_MAP[args.torch_dtype], + ) + logger.info(f"Created model: {process_idx=} {model.device=}") + + torch.manual_seed(42) + for step in range(args.n_steps): + if step == args.warmup_steps: + start_time = perf_counter() + + input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len)) + + logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}") + h = model(input_ids) + # We don't use model.lm_head + logger.info(f"{process_idx=} Fwd end") + + if step >= args.warmup_steps: + speed = step / (perf_counter() - start_time) * input_ids.numel() + logger.info(f"{process_idx=} {step=} {speed=:.3f}") + + logger.info(f"Final result: {process_idx=} {speed=:.3f}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_inference.py b/benchmarks/benchmark_inference.py new file mode 100755 index 0000000..7b5f0e1 --- /dev/null +++ b/benchmarks/benchmark_inference.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 + +import argparse +import multiprocessing as mp +from time import perf_counter + +import torch +from hivemind.utils.logging import get_logger +from transformers import AutoTokenizer + +from petals import AutoDistributedModelForCausalLM +from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS + +logger = get_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="bigscience/bloom") + parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS) + parser.add_argument("--torch_dtype", type=str, default="bfloat16") + parser.add_argument("--n_processes", type=str, default=1) + parser.add_argument("--seq_len", type=int, default=2048) + parser.add_argument("--warmup_steps", type=int, default=1) + args = parser.parse_args() + + if args.n_processes == "n_gpus": + args.n_processes = torch.cuda.device_count() + else: + args.n_processes = int(args.n_processes) + + processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)] + for proc in processes: + proc.start() + for proc in processes: + proc.join() + + +@torch.inference_mode() +def benchmark_inference(process_idx, args): + tokenizer = AutoTokenizer.from_pretrained(args.model) + model = AutoDistributedModelForCausalLM.from_pretrained( + args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype] + ) + logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}") + + result = "" + with model.transformer.h.inference_session(max_length=args.seq_len) as sess: + for step in range(args.seq_len): + if step == args.warmup_steps: + start_time = perf_counter() + + outputs = model.generate(max_new_tokens=1, session=sess) + result += tokenizer.decode(outputs[0]) + + if step >= args.warmup_steps: + speed = step / (perf_counter() - start_time) + logger.info(f"{process_idx=} {step=} {speed=:.3f}") + + logger.info(f"Final result: {process_idx=} {speed=:.3f}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_training.py b/benchmarks/benchmark_training.py new file mode 100755 index 0000000..46d0eb2 --- /dev/null +++ b/benchmarks/benchmark_training.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 + +import argparse +import multiprocessing as mp +from time import perf_counter + +import numpy as np +import torch +from hivemind.utils.logging import get_logger + +from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification +from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS + +logger = get_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="bigscience/bloom") + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--task", type=str, default="cls") + parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS) + parser.add_argument("--torch_dtype", type=str, default="bfloat16") + parser.add_argument("--n_processes", type=str, default=1) + parser.add_argument("--seq_len", type=int, default=128) + parser.add_argument("--pre_seq_len", type=int, default=16) + parser.add_argument("--n_steps", type=int, default=10) + parser.add_argument("--batch_size", type=int, required=True) + parser.add_argument("--warmup_steps", type=int, default=1) + args = parser.parse_args() + + assert args.task in ["cls", "causal_lm"] + + if args.n_processes == "n_gpus": + args.n_processes = torch.cuda.device_count() + else: + args.n_processes = int(args.n_processes) + + processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)] + for proc in processes: + proc.start() + for proc in processes: + proc.join() + + +def benchmark_training(process_idx, args): + if args.task == "cls": + model = AutoDistributedModelForSequenceClassification.from_pretrained( + args.model, + initial_peers=args.initial_peers, + torch_dtype=DTYPE_MAP[args.torch_dtype], + tuning_mode="deep_ptune", + pre_seq_len=args.pre_seq_len, + num_labels=2, + ) + elif args.task == "causal_lm": + model = AutoDistributedModelForCausalLM.from_pretrained( + args.model, + initial_peers=args.initial_peers, + torch_dtype=DTYPE_MAP[args.torch_dtype], + tuning_mode="deep_ptune", + pre_seq_len=args.pre_seq_len, + ) + model = model.to(args.device) + opt = torch.optim.Adam(model.parameters()) + logger.info(f"Created model: {process_idx=} {model.device=}") + + torch.manual_seed(42) + fwd_times = [] + bwd_times = [] + for step in range(args.n_steps): + input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device) + if args.task == "cls": + labels = torch.randint(0, 2, size=[args.batch_size], device=args.device) + else: + labels = input_ids + + logger.info(f"{process_idx=} {step=} Forward") + start_time = perf_counter() + outputs = model(input_ids, labels=labels) + fwd_times.append(perf_counter() - start_time) + + logger.info(f"{process_idx=} {step=} Backward") + start_time = perf_counter() + outputs.loss.backward() + bwd_times.append(perf_counter() - start_time) + + logger.info(f"{process_idx=} {step=} Optimizer step") + opt.step() + opt.zero_grad() + + if step >= args.warmup_steps: + fwd_speed = input_ids.numel() / np.mean(fwd_times[1:]) + bwd_speed = input_ids.numel() / np.mean(bwd_times[1:]) + logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}") + + logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}") + + +if __name__ == "__main__": + main() diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 4c6f0e5..83e35e5 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -6,8 +6,8 @@ from hivemind.utils.limits import increase_file_limit from hivemind.utils.logging import get_logger from humanfriendly import parse_size -from petals.constants import PUBLIC_INITIAL_PEERS -from petals.server.server import DTYPE_MAP, Server +from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS +from petals.server.server import Server from petals.utils.version import validate_version logger = get_logger(__name__) diff --git a/src/petals/constants.py b/src/petals/constants.py index da047f1..b04ad03 100644 --- a/src/petals/constants.py +++ b/src/petals/constants.py @@ -1,3 +1,5 @@ +import torch + PUBLIC_INITIAL_PEERS = [ "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", @@ -7,3 +9,5 @@ PUBLIC_INITIAL_PEERS = [ # The reachability API is currently used only when connecting to the public swarm REACHABILITY_API_URL = "http://health.petals.ml" + +DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index e4961d3..7644148 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -128,7 +128,7 @@ class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSeq self.num_labels = config.num_labels self.transformer = DistributedBloomModel(config) - self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index aab8a9e..62b9959 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -19,6 +19,7 @@ from huggingface_hub import get_hf_file_metadata, hf_hub_url from transformers import PretrainedConfig from transformers.utils import get_file_from_repo +from petals.constants import DTYPE_MAP from petals.server.block_utils import resolve_block_dtype from petals.utils.auto_config import AutoDistributedConfig from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for @@ -170,6 +171,3 @@ def _load_state_dict_from_file( except Exception as e: logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True) time.sleep(delay) - - -DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 75a999e..39c432c 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -16,13 +16,13 @@ from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger from transformers import PretrainedConfig -from petals.constants import PUBLIC_INITIAL_PEERS +from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size, resolve_block_dtype -from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block +from petals.server.from_pretrained import load_pretrained_block from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability From 10c72acdf4bf17f8f4b6405784310fc6386c412a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 30 Jun 2023 04:18:43 +0400 Subject: [PATCH 095/168] Fix warmup steps and minor issues in benchmarks (#334) The previous code was incorrect for the case of `warmup_steps != 1` (this mode was never used, but can be used in future). --- benchmarks/benchmark_forward.py | 14 ++++++++------ benchmarks/benchmark_inference.py | 18 +++++++++++------- benchmarks/benchmark_training.py | 12 +++++++----- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/benchmarks/benchmark_forward.py b/benchmarks/benchmark_forward.py index 0a7d4f8..e95c5ec 100755 --- a/benchmarks/benchmark_forward.py +++ b/benchmarks/benchmark_forward.py @@ -4,6 +4,7 @@ import argparse import multiprocessing as mp from time import perf_counter +import numpy as np import torch from hivemind.utils.logging import get_logger @@ -47,9 +48,9 @@ def benchmark_forward(process_idx, args): logger.info(f"Created model: {process_idx=} {model.device=}") torch.manual_seed(42) - for step in range(args.n_steps): - if step == args.warmup_steps: - start_time = perf_counter() + step_times = [] + for step in range(args.warmup_steps + args.n_steps): + start_time = perf_counter() input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len)) @@ -59,10 +60,11 @@ def benchmark_forward(process_idx, args): logger.info(f"{process_idx=} Fwd end") if step >= args.warmup_steps: - speed = step / (perf_counter() - start_time) * input_ids.numel() - logger.info(f"{process_idx=} {step=} {speed=:.3f}") + step_times.append(perf_counter() - start_time) + speed = input_ids.numel() / np.mean(step_times) + logger.info(f"{process_idx=} {step=} {speed=:.2f}") - logger.info(f"Final result: {process_idx=} {speed=:.3f}") + logger.info(f"Final result: {process_idx=} {speed=:.2f}") if __name__ == "__main__": diff --git a/benchmarks/benchmark_inference.py b/benchmarks/benchmark_inference.py index 7b5f0e1..607ff88 100755 --- a/benchmarks/benchmark_inference.py +++ b/benchmarks/benchmark_inference.py @@ -4,6 +4,7 @@ import argparse import multiprocessing as mp from time import perf_counter +import numpy as np import torch from hivemind.utils.logging import get_logger from transformers import AutoTokenizer @@ -38,26 +39,29 @@ def main(): @torch.inference_mode() def benchmark_inference(process_idx, args): - tokenizer = AutoTokenizer.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) + # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway + model = AutoDistributedModelForCausalLM.from_pretrained( args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype] ) - logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}") + logger.info(f"Created model: {process_idx=} {model.device=}") result = "" + step_times = [] with model.transformer.h.inference_session(max_length=args.seq_len) as sess: for step in range(args.seq_len): - if step == args.warmup_steps: - start_time = perf_counter() + start_time = perf_counter() outputs = model.generate(max_new_tokens=1, session=sess) result += tokenizer.decode(outputs[0]) if step >= args.warmup_steps: - speed = step / (perf_counter() - start_time) - logger.info(f"{process_idx=} {step=} {speed=:.3f}") + step_times.append(perf_counter() - start_time) + speed = 1 / np.mean(step_times) + logger.info(f"{process_idx=} {step=} {speed=:.2f}") - logger.info(f"Final result: {process_idx=} {speed=:.3f}") + logger.info(f"Final result: {process_idx=} {speed=:.2f}") if __name__ == "__main__": diff --git a/benchmarks/benchmark_training.py b/benchmarks/benchmark_training.py index 46d0eb2..0853dfc 100755 --- a/benchmarks/benchmark_training.py +++ b/benchmarks/benchmark_training.py @@ -68,7 +68,7 @@ def benchmark_training(process_idx, args): torch.manual_seed(42) fwd_times = [] bwd_times = [] - for step in range(args.n_steps): + for step in range(args.warmup_steps + args.n_steps): input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device) if args.task == "cls": labels = torch.randint(0, 2, size=[args.batch_size], device=args.device) @@ -78,20 +78,22 @@ def benchmark_training(process_idx, args): logger.info(f"{process_idx=} {step=} Forward") start_time = perf_counter() outputs = model(input_ids, labels=labels) - fwd_times.append(perf_counter() - start_time) + if step >= args.warmup_steps: + fwd_times.append(perf_counter() - start_time) logger.info(f"{process_idx=} {step=} Backward") start_time = perf_counter() outputs.loss.backward() - bwd_times.append(perf_counter() - start_time) + if step >= args.warmup_steps: + bwd_times.append(perf_counter() - start_time) logger.info(f"{process_idx=} {step=} Optimizer step") opt.step() opt.zero_grad() if step >= args.warmup_steps: - fwd_speed = input_ids.numel() / np.mean(fwd_times[1:]) - bwd_speed = input_ids.numel() / np.mean(bwd_times[1:]) + fwd_speed = input_ids.numel() / np.mean(fwd_times) + bwd_speed = input_ids.numel() / np.mean(bwd_times) logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}") logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}") From 66a47c763efc0bb1b49208a195421bc411db5338 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 2 Jul 2023 03:32:51 +0400 Subject: [PATCH 096/168] Require pydantic < 2.0 (2.0 is incompatible with hivemind 1.1.8) (#337) See https://github.com/learning-at-home/hivemind/pull/573. --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 4722c63..eacd99a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,7 @@ install_requires = tokenizers>=0.13.3 transformers>=4.30.1,<5.0.0 speedtest-cli==2.1.3 + pydantic>=1.8.1,<2.0 # 2.0 is incompatible with hivemind==1.1.8 hivemind==1.1.8 tensor_parallel==1.0.23 humanfriendly From de930918a0743da011caeccd7131f53278a1e8ae Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 3 Jul 2023 20:13:04 +0400 Subject: [PATCH 097/168] Support loading blocks in 4-bit (QLoRA NF4 format, disabled by default) (#333) --- setup.cfg | 2 +- src/petals/cli/run_server.py | 14 +++--- src/petals/server/block_utils.py | 27 ++++++----- src/petals/server/server.py | 30 ++++++------ src/petals/server/throughput.py | 24 +++++----- src/petals/utils/convert_block.py | 78 +++++++++++++++++-------------- tests/test_aux_functions.py | 3 +- tests/test_remote_sequential.py | 5 +- 8 files changed, 100 insertions(+), 83 deletions(-) diff --git a/setup.cfg b/setup.cfg index eacd99a..fb1fa23 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ packages = find: python_requires = >=3.7 install_requires = torch>=1.12 - bitsandbytes==0.38.0.post2 + bitsandbytes==0.39.1 accelerate>=0.16.0,<1.0.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 83e35e5..3c28709 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -8,6 +8,7 @@ from humanfriendly import parse_size from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS from petals.server.server import Server +from petals.utils.convert_block import QuantType from petals.utils.version import validate_version logger = get_logger(__name__) @@ -133,9 +134,10 @@ def main(): help="Check the swarm's balance every N seconds (and rebalance it if necessary)") parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained") - parser.add_argument('--load_in_8bit', type=str, default=None, - help="Convert the loaded transformer blocks into mixed-8bit quantized model. " - "Default: True if GPU is available. Use `--load_in_8bit False` to disable this") + parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType], + help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or " + "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. " + "Default: 'int8' if GPU is available, 'none' otherwise") parser.add_argument("--tensor_parallel_devices", nargs='+', default=None, help= "Split each block between the specified GPUs such that each device holds a portion of every " @@ -186,9 +188,9 @@ def main(): if args.pop("new_swarm"): args["initial_peers"] = [] - load_in_8bit = args.pop("load_in_8bit") - if load_in_8bit is not None: - args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"] + quant_type = args.pop("quant_type") + if quant_type is not None: + args["quant_type"] = QuantType[quant_type.upper()] validate_version() diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index a6af3b0..eb5300e 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -4,6 +4,8 @@ import torch from accelerate import init_empty_weights from transformers import PretrainedConfig +from petals.utils.convert_block import QuantType + def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype: """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise.""" @@ -19,27 +21,30 @@ def get_block_size( location: str, *, dtype: Optional[Union[str, torch.dtype]] = None, - load_in_8bit: Optional[bool] = None, + quant_type: QuantType = QuantType.NONE, eps: float = 0.01, # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc. ) -> int: if location == "memory": assert ( - dtype is not None and load_in_8bit is not None - ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations' + dtype is not None and quant_type is not None + ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations' with init_empty_weights(include_buffers=True): block = config.block_class(config) n_params = sum(param.numel() for param in block.parameters()) - if location == "memory" and load_in_8bit: - # Note: We may need a larger eps here for models of size < 1B - return n_params * (1 + eps) - if location == "memory": - dtype = resolve_block_dtype(config, dtype) + if quant_type == QuantType.NONE: + dtype = resolve_block_dtype(config, dtype) + bytes_per_value = torch.finfo(dtype).bits // 8 + elif quant_type == QuantType.INT8: + bytes_per_value = 1 + elif quant_type == QuantType.NF4: + bytes_per_value = 4.25 / 8 # Bitness of NF4 with this config (measured empirically) + else: + raise ValueError(f"Unsupported quant_type={quant_type}") elif location == "disk": dtype = resolve_block_dtype(config, "auto") - else: - raise ValueError('get_block_size() expects location to be "memory" or "disk"') + bytes_per_value = torch.finfo(dtype).bits // 8 - return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps)) + return round(n_params * bytes_per_value * (1 + eps)) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 39c432c..2fbaad2 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -28,7 +28,7 @@ from petals.server.memory_cache import MemoryCache from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig -from petals.utils.convert_block import check_device_balance, convert_block +from petals.utils.convert_block import QuantType, check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR from petals.utils.version import get_compatible_model_repo @@ -75,7 +75,7 @@ class Server: mean_balance_check_period: float = 120, mean_block_selection_delay: float = 2.5, use_auth_token: Optional[str] = None, - load_in_8bit: Optional[bool] = None, + quant_type: Optional[QuantType] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, dht_client_mode: Optional[bool] = None, @@ -154,8 +154,8 @@ class Server: device = torch.device(device.type, index=0) self.device = device - torch_dtype = DTYPE_MAP[torch_dtype] - self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype) + torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype]) + self.torch_dtype = torch_dtype if tensor_parallel_devices is None: tensor_parallel_devices = (device,) @@ -164,10 +164,10 @@ class Server: logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}") check_device_balance(self.tensor_parallel_devices) - if load_in_8bit is None: - load_in_8bit = device.type == "cuda" - self.load_in_8bit = load_in_8bit - logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format") + if quant_type is None: + quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE + self.quant_type = quant_type + logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format") cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8 @@ -203,7 +203,7 @@ class Server: device, torch_dtype, num_blocks=num_blocks, - load_in_8bit=load_in_8bit, + quant_type=quant_type, tensor_parallel_devices=self.tensor_parallel_devices, force_eval=(throughput == "eval"), cache_dir=cache_dir, @@ -237,11 +237,11 @@ class Server: else: total_memory = torch.cuda.get_device_properties(self.device).total_memory - block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit) + block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type) - # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models gib = 1024**3 - autograd_memory = 2 * gib * num_devices # GPU memory used for intermediate tensors in rpc_backward + # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models) + autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block)) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" @@ -284,7 +284,7 @@ class Server: sender_threads=self.sender_threads, revision=self.revision, use_auth_token=self.use_auth_token, - load_in_8bit=self.load_in_8bit, + quant_type=self.quant_type, tensor_parallel_devices=self.tensor_parallel_devices, should_validate_reachability=self.should_validate_reachability, start=True, @@ -377,7 +377,7 @@ class ModuleContainer(threading.Thread): expiration: Optional[float], revision: Optional[str], use_auth_token: Optional[str], - load_in_8bit: bool, + quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], should_validate_reachability: bool, **kwargs, @@ -411,7 +411,7 @@ class ModuleContainer(threading.Thread): cache_dir=cache_dir, max_disk_space=max_disk_space, ) - block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True) + block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True) blocks[module_uid] = TransformerBackend( module_uid, block, diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 2ee1ca1..76bbc85 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -13,7 +13,7 @@ from hivemind.utils.logging import get_logger from transformers import PretrainedConfig from petals.server.block_utils import resolve_block_dtype -from petals.utils.convert_block import convert_block +from petals.utils.convert_block import QuantType, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR logger = get_logger(__name__) @@ -39,7 +39,7 @@ def get_server_throughput( dtype: Union[str, torch.dtype], *, num_blocks: int, - load_in_8bit: bool, + quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], force_eval: bool = False, cache_dir: Optional[str] = None, @@ -60,7 +60,7 @@ def get_server_throughput( cache_key = f"model_{model_name}" cache_key += f"_device_{get_device_name(device).replace(' ', '_')}" - cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}" + cache_key += f"_dtype_{get_dtype_name(dtype, quant_type)}" if len(tensor_parallel_devices) > 1: for i, device_i in enumerate(tensor_parallel_devices): cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}" @@ -77,7 +77,7 @@ def get_server_throughput( if cache_key not in cache: cache[cache_key] = measure_throughput_info( - config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices + config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices ) try: @@ -104,7 +104,7 @@ def measure_throughput_info( device: torch.device, dtype: torch.dtype, *, - load_in_8bit: bool, + quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], ) -> Dict[str, float]: """Measure network and compute throughput in forward pass tokens per second""" @@ -115,7 +115,7 @@ def measure_throughput_info( throughput_info = { "compute_rps": measure_compute_rps( - config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices + config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices ) } try: @@ -163,7 +163,7 @@ def measure_compute_rps( device: torch.device, dtype: torch.dtype, *, - load_in_8bit: bool, + quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], n_tokens: int = 16, n_steps: int = 500, @@ -172,7 +172,7 @@ def measure_compute_rps( tensor_parallel_devices = (device,) with torch.inference_mode(): block = config.block_class(config).to(dtype) - block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True) + block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) cache = None elapsed = 0 @@ -192,7 +192,7 @@ def measure_compute_rps( logger.info( f"Forward pass throughput: {device_rps:.1f} RPS per block " - f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})" + f"({devices_repr}, {get_dtype_name(dtype, quant_type)})" ) return device_rps @@ -201,8 +201,8 @@ def get_device_name(device: torch.device) -> str: return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU" -def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str: +def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str: name = str(dtype) - if load_in_8bit: - name += ", 8-bit quantized" + if quant_type != QuantType.NONE: + name += f", quantized to {quant_type.name.lower()}" return name diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 28aea56..6b129f5 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -3,6 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor par """ import os import re +from enum import Enum from typing import Sequence import tensor_parallel as tp @@ -16,13 +17,18 @@ use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) +class QuantType(Enum): + NONE = 0 + INT8 = 1 # 8-bit as in the LLM.int8() paper + NF4 = 2 # 4-bit as in the QLoRA paper + + def convert_block( block: nn.Module, config: PretrainedConfig, tensor_parallel_devices: Sequence[torch.device], output_device: torch.device, - load_in_8bit: bool, - threshold: float = 6.0, + quant_type: QuantType, freeze: bool = True, ) -> tp.TensorParallel: """ @@ -34,20 +40,18 @@ def convert_block( :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity) :param output_device: if tensor_parallel_devices is True, output - :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint - :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 ) + :param quant_type: quantization type :param freeze: if True (default), make all module parameters non-trainable :return: a module that acts like the original block, but runs with all specified optimizations """ if freeze: - for param in block.parameters(): - param.requires_grad = False + block.requires_grad_(False) block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device) - if load_in_8bit: - block = replace_8bit_linear(block, threshold=threshold) + if quant_type != QuantType.NONE: + block = quantize_module(block, quant_type=quant_type) for shard, device in zip(block.module_shards, block.devices): shard.to(device) @@ -55,43 +59,45 @@ def convert_block( return block -def replace_8bit_linear(model: nn.Module, threshold=6.0) -> nn.Module: - """ - A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes` - library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): - 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA - version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ - bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116) - The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should - be kept as a `torch.nn.Linear` module. - Parameters: - model (`torch.nn.Module`): - Input model or `torch.nn.Module` as the function is run recursively. - threshold (`float`, *optional*): - `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to - `6.0` as described by the paper. - """ - +def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module: # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes os.environ["BITSANDBYTES_NOWELCOME"] = "1" import bitsandbytes as bnb for n, module in model.named_children(): if len(list(module.children())) > 0: - replace_8bit_linear(module, threshold) + quantize_module(module, quant_type=quant_type) if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]: assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}" - model._modules[n] = bnb.nn.Linear8bitLt( - module.in_features, - module.out_features, - module.bias is not None, - has_fp16_weights=False, - threshold=threshold, - ) - model._modules[n].weight = bnb.nn.Int8Params( - module.weight.data, requires_grad=False, has_fp16_weights=False - ).to(module.weight.dtype) + if quant_type == QuantType.INT8: + model._modules[n] = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=6.0, # Default from the LLM.int8() paper + ) + model._modules[n].weight = bnb.nn.Int8Params( + module.weight.data, requires_grad=False, has_fp16_weights=False + ).to(module.weight.dtype) + elif quant_type == QuantType.NF4: + compress_statistics = True + model._modules[n] = bnb.nn.LinearNF4( + module.in_features, + module.out_features, + module.bias is not None, + compress_statistics=compress_statistics, + ) + model._modules[n].weight = bnb.nn.Params4bit( + module.weight.data, + requires_grad=False, + quant_type="nf4", + blocksize=64, + compress_statistics=compress_statistics, + ).to(module.weight.dtype) + else: + raise ValueError(f"Unsupported quant_type='{quant_type}'") model._modules[n].bias = module.bias return model diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index d42666b..5fa14db 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -3,6 +3,7 @@ import torch from petals import AutoDistributedConfig from petals.server.throughput import measure_compute_rps +from petals.utils.convert_block import QuantType from test_utils import MODEL_NAME @@ -15,7 +16,7 @@ def test_compute_throughput(tensor_parallel: bool): config, device=torch.device("cpu"), dtype=torch.bfloat16, - load_in_8bit=False, + quant_type=QuantType.NONE, tensor_parallel_devices=tensor_parallel_devices, n_steps=10, ) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 734683f..3c8a48f 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -78,7 +78,10 @@ class DummyCustomSequenceManager(RemoteSequenceManager): if protocol == "rpc_forward": metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) elif protocol == "rpc_backward": - metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,) + metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) + # FIXME: Initially, we used CompressionType.BLOCKWISE_8BIT for rpc_backward() here. + # This is currently broken since hivemind==1.1.8 is not compatible with bitsandbytes==0.39.1. + # Please revert to BLOCKWISE_8BIT once this is fixed: https://github.com/learning-at-home/hivemind/issues/572 return metadata From 4d9c26fe5c46e7f2f669c9d6e6e378f5934f4605 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 5 Jul 2023 14:57:59 +0400 Subject: [PATCH 098/168] Allow free_disk_space_for() remove arbitrary files from Petals cache (#339) Before this PR, `free_disk_space_for()` was able to remove **(a)** only entire cached revisions (= git commits/branches) and **(b)** only from the repository we're loading right now. This PR allows this functions to remove arbitrary files separately from any repositories. This is useful for transition to Petals 1.2.0+, since it now uses original repos instead of the ones with converted models (see #323). In particular, the cache for `bigscience/bloom-petals` is now deprecated and should be removed in favor of `bigscience/bloom`. This is also useful as a way to free space before loading LoRA adapters (#335). --- src/petals/server/from_pretrained.py | 2 +- src/petals/utils/disk_cache.py | 33 ++++++++++++---------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 62b9959..41fb989 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -153,7 +153,7 @@ def _load_state_dict_from_file( url = hf_hub_url(model_name, filename, revision=revision) file_size = get_hf_file_metadata(url, token=use_auth_token).size if file_size is not None: - free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) + free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) else: logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}") diff --git a/src/petals/utils/disk_cache.py b/src/petals/utils/disk_cache.py index aefea1d..a26a0f5 100644 --- a/src/petals/utils/disk_cache.py +++ b/src/petals/utils/disk_cache.py @@ -33,15 +33,12 @@ def allow_cache_reads(cache_dir: Optional[str]): return _blocks_lock(cache_dir, fcntl.LOCK_SH) -def allow_cache_writes( - cache_dir: Optional[str], *, reserve: Optional[int] = None, max_disk_space: Optional[int] = None -): +def allow_cache_writes(cache_dir: Optional[str]): """Allows saving new blocks and removing the old ones (exclusive lock)""" return _blocks_lock(cache_dir, fcntl.LOCK_EX) def free_disk_space_for( - model_name: str, size: int, *, cache_dir: Optional[str], @@ -51,35 +48,33 @@ def free_disk_space_for( if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR cache_info = huggingface_hub.scan_cache_dir(cache_dir) - model_repos = [repo for repo in cache_info.repos if repo.repo_type == "model" and repo.repo_id == model_name] - occupied_space = sum(repo.size_on_disk for repo in model_repos) available_space = shutil.disk_usage(cache_dir).free - os_quota if max_disk_space is not None: - available_space = min(available_space, max_disk_space - occupied_space) + available_space = min(available_space, max_disk_space - cache_info.size_on_disk) gib = 1024**3 logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB") if size <= available_space: return - revisions = [revision for repo in model_repos for revision in repo.revisions] - revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified)) + cached_files = [file for repo in cache_info.repos for revision in repo.revisions for file in revision.files] - # Remove as few least recently used shards as possible - pending_removal = [] + # Remove as few least recently used files as possible + removed_files = [] freed_space = 0 extra_space_needed = size - available_space - for rev in revisions: - pending_removal.append(rev.commit_hash) - freed_space += rev.size_on_disk + for file in sorted(cached_files, key=lambda file: file.blob_last_accessed): + os.remove(file.file_path) # Remove symlink + os.remove(file.blob_path) # Remove contents + + removed_files.append(file) + freed_space += file.size_on_disk if freed_space >= extra_space_needed: break - - if pending_removal: - logger.info(f"Removing {len(pending_removal)} shards to free {freed_space / gib:.1f} GiB of disk space") - delete_strategy = cache_info.delete_revisions(*pending_removal) - delete_strategy.execute() + if removed_files: + logger.info(f"Removed {len(removed_files)} files to free {freed_space / gib:.1f} GiB of disk space") + logger.debug(f"Removed paths: {[str(file.file_path) for file in removed_files]}") if freed_space < extra_space_needed: raise RuntimeError( From 158013a6715a97686493bddf1904ba6455a2a6b6 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 11 Jul 2023 17:29:34 +0400 Subject: [PATCH 099/168] Implement direct server-to-server communication (#331) Implement #226. --- src/petals/__init__.py | 2 +- src/petals/cli/run_server.py | 3 +- src/petals/client/inference_session.py | 228 ++++++++++-------- src/petals/client/routing/sequence_manager.py | 1 + src/petals/server/handler.py | 200 +++++++++++++-- src/petals/server/server.py | 32 ++- 6 files changed, 334 insertions(+), 132 deletions(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 26aa3ab..f007d11 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -9,7 +9,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.2.0.dev0" +__version__ = "1.2.0.dev1" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 3c28709..1d3c438 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -27,8 +27,7 @@ def main(): parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve") parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve") - parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default," - "use the same name as in the converted model.") + parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix") parser.add_argument('--port', type=int, required=False, help='Port this server listens to. ' diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 168dd40..8c2dfc9 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -3,7 +3,8 @@ from __future__ import annotations import asyncio import itertools import time -from typing import AsyncIterator, List, Optional +import uuid +from typing import AsyncIterator, List, Optional, Tuple import torch from hivemind import ( @@ -15,10 +16,10 @@ from hivemind import ( serialize_torch_tensor, ) from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import StubBase +from hivemind.p2p import P2P from hivemind.proto import runtime_pb2 -from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback +from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, is_dummy @@ -35,35 +36,48 @@ class _ServerInferenceSession: def __init__( self, + config: SequenceManagerConfig, + span: RemoteSpanInfo, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator, *, - timeout: float, max_length: int, **metadata, ): - self.uid, self.rpc_info = uid, rpc_info + self.config = config + self.span, self.uid, self.rpc_info = span, uid, rpc_info self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter - self.timeout = timeout - self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, **metadata)) + self.session_id = str(uuid.uuid4()) + self.session_metadata = dict(max_length=max_length, **metadata) self.stepped = False self.closed = False + self._position = 0 + self.history = None # Used in case of server failures to regenerate attention caches on new servers + self.next_session = None + @classmethod async def create( - cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata + cls, + config: SequenceManagerConfig, + p2p: P2P, + span: RemoteSpanInfo, + uid: ModuleUID, + rpc_info: RPCInfo, + **metadata, ) -> _ServerInferenceSession: """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker""" + stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id) inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), - timeout, + config.request_timeout, ) - return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata) + return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata) @staticmethod async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator: @@ -75,9 +89,11 @@ class _ServerInferenceSession: def step( self, - new_hidden_states: torch.Tensor, + inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None, + *, + step_id: str, ) -> torch.Tensor: """ Inference step: send a chunk of input tesors and receive a chunk of outputs @@ -86,44 +102,84 @@ class _ServerInferenceSession: """ if self.closed: raise Exception("Session is closed, cannot perform step") + + n_input_tokens = inputs.shape[1] + if self.history is None: + self.history = inputs + elif self.history.shape[1] == self._position: + self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1) + assert self.history.shape[1] == self._position + n_input_tokens, ( + f"Broken input cache: span={self.span} shape={self.history.shape} " + f"position={self._position} n_input_tokens={n_input_tokens}" + ) + + if not self.stepped: + inputs = self.history # Pass full inputs including prefix + else: + inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further + if prompts is None or is_dummy(prompts): prompts = DUMMY else: - assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]" + assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" assert prompts.shape[0] == self.num_blocks - assert prompts.shape[1] in (new_hidden_states.shape[0], 1) - assert prompts.shape[2] <= new_hidden_states.shape[1] - assert prompts.shape[3] == new_hidden_states.shape[2] + assert prompts.shape[1] in (inputs.shape[0], 1) + assert prompts.shape[2] <= inputs.shape[1] + assert prompts.shape[3] == inputs.shape[2] if hypo_ids is None or is_dummy(hypo_ids): hypo_ids = DUMMY else: - assert len(hypo_ids) == len(new_hidden_states) + assert len(hypo_ids) == len(inputs) assert hypo_ids.dtype == torch.int64 # serialize inputs and put them into the queue - inputs = (new_hidden_states, prompts, hypo_ids) + input_tensors = (inputs, prompts, hypo_ids) + + request_metadata = dict(session_id=self.session_id, step_id=step_id) + if not self.stepped: + request_metadata.update(self.session_metadata) + elif self.config.use_server_to_server: + next_servers = self._collect_next_servers() + if next_servers: + request_metadata["next_servers"] = next_servers + outputs_serialized = RemoteExpertWorker.run_coroutine( self._step( runtime_pb2.ExpertRequest( uid=self.uid, tensors=[ serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"])) + for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"])) ], - metadata=self._serialized_metadata if not self.stepped else None, + metadata=MSGPackSerializer.dumps(request_metadata), ) ) ) outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors)) - assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}" + assert ( + outputs[0].shape == inputs.shape + ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}" + + self._position += n_input_tokens + return outputs[0] + def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]: + next_servers = [] + session = self.next_session + while session is not None and session.stepped: + next_servers.append( + (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end) + ) + session = session.next_session + return next_servers + async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse: """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker""" await self._inputs_queue.put(inputs_serialized) self.stepped = True - return await asyncio.wait_for(anext(self._outputs_stream), self.timeout) + return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout) def close(self): """Finish a given inference session, close the underlying connection""" @@ -163,13 +219,15 @@ class InferenceSession: def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int): self._sequence_manager = sequence_manager self._closed = False - self._chosen_spans = [] self._server_sessions = [] - self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers self._position = 0 self._max_length = max_length self.last_token_id = None + @property + def num_blocks(self) -> int: + return len(self._sequence_manager) + @property def position(self) -> int: return self._position @@ -178,15 +236,15 @@ class InferenceSession: server_sessions = [] try: for span in chosen_spans: - stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id) span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end]) metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id) session = RemoteExpertWorker.run_coroutine( _ServerInferenceSession.create( - stub, + self._sequence_manager.config, + self._sequence_manager.state.p2p, + span, span_uids, rpc_info=self._sequence_manager.rpc_info, - timeout=self._sequence_manager.config.request_timeout, max_length=self._max_length, **metadata, ) @@ -206,7 +264,7 @@ class InferenceSession: logger.debug("Caught exception while closing connection to server:", exc_info=True) def __enter__(self) -> "InferenceSession": - assert not self._closed and not self._chosen_spans + assert not self._closed and not self._server_sessions return self def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: @@ -214,16 +272,17 @@ class InferenceSession: if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.") - n_blocks = len(self._sequence_manager) if prompts is None or is_dummy(prompts): prompts = DUMMY else: - assert prompts.ndim == 4 and prompts.shape[0] == n_blocks + assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" + assert prompts.shape[0] == self.num_blocks inputs_device = inputs.device inputs_dtype = inputs.dtype inputs = inputs.cpu() prompts = prompts.cpu() + step_id = str(uuid.uuid4()) n_input_tokens = inputs.shape[1] if self._position + n_input_tokens > self._max_length: @@ -233,97 +292,74 @@ class InferenceSession: server_idx = 0 block_idx = 0 - recovery_until = -1 # Recovery mode is disabled until a failure happens - while block_idx < n_blocks: + while block_idx < self.num_blocks: for attempt_no in itertools.count(): logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}") - span = None + server_session = None try: - if not self._chosen_spans or not self._server_sessions or attempt_no >= 1: - # If there is a failed server session, this code closes it - self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) - - n_prev_spans = len(self._chosen_spans) - update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks - if attempt_no >= 1 and update_end > recovery_until: - logger.info( - f"Due to a server failure, remote attention caches " - f"from block {block_idx} to {update_end} will be regenerated" - ) - recovery_until = max(recovery_until, update_end) - - updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency") - # make_sequence() could return a longer sequence - updated_spans[-1].end = min(updated_spans[-1].end, update_end) - updated_sessions = self._enter_server_sessions(updated_spans) - logger.debug( - f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers" - ) - - # If there is a failed span, this code replaces it, otherwise it just adds new ones - self._chosen_spans[server_idx : server_idx + 1] = updated_spans - self._server_sessions[server_idx : server_idx + 1] = updated_sessions - recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None - self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * ( - len(updated_spans) - 1 - ) - assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), ( - f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " - f"{len(self._server_inputs)} inputs" - ) - - session = self._server_sessions[server_idx] - span = self._chosen_spans[server_idx] - - if self._server_inputs[server_idx] is None: - self._server_inputs[server_idx] = inputs - elif self._server_inputs[server_idx].shape[1] == self._position: - self._server_inputs[server_idx] = torch.cat( - [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1 - ) - assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, ( - f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " - f"position={self._position} n_input_tokens={n_input_tokens}" - ) - - if not session.stepped: - inputs = self._server_inputs[server_idx] # Pass full inputs including prefix - else: - inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further + if not self._server_sessions or attempt_no >= 1: + self._update_sequence(server_idx, block_idx, attempt_no) - outputs = session.step(inputs, prompts[span.start : span.end], **kwargs) - assert ( - inputs.shape == outputs.shape - ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})" + server_session = self._server_sessions[server_idx] + inputs = server_session.step( + inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs + ) - inputs = outputs server_idx += 1 - block_idx = span.end - self._sequence_manager.on_request_success(span.peer_id) + block_idx = server_session.span.end + self._sequence_manager.on_request_success(server_session.span.peer_id) break except Exception as e: - self._sequence_manager.on_request_failure(span.peer_id if span is not None else None) + self._sequence_manager.on_request_failure( + server_session.span.peer_id if server_session is not None else None + ) if attempt_no + 1 == self._sequence_manager.config.max_retries: raise delay = self._sequence_manager.get_retry_delay(attempt_no) logger.warning( - f"Caught exception when running inference via {span} (retry in {delay:.0f} sec): {repr(e)}" + f"Caught exception when running inference via {server_session.span if server_session is not None else None} " + f"(retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) time.sleep(delay) self._position += n_input_tokens - inputs = inputs[:, -n_input_tokens:] - outputs = inputs.to(device=inputs_device, dtype=inputs_dtype) + outputs = inputs[:, -n_input_tokens:] + outputs = outputs.to(device=inputs_device, dtype=inputs_dtype) return outputs + def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int: + # If there is a failed server session, this code closes it + self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) + + n_prev_spans = len(self._server_sessions) + update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks + if attempt_no >= 1: + logger.info( + f"Due to a server failure, remote attention caches " + f"from block {block_idx} to {update_end} will be regenerated" + ) + + updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency") + # make_sequence() could return a longer sequence + updated_spans[-1].end = min(updated_spans[-1].end, update_end) + updated_sessions = self._enter_server_sessions(updated_spans) + logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers") + + # If there is a failed span, this code replaces it, otherwise it just adds new ones + if server_idx < n_prev_spans: + updated_sessions[0].history = self._server_sessions[server_idx].history + self._server_sessions[server_idx : server_idx + 1] = updated_sessions + + # Update links to the next server session for direct server-to-server communication via rpc_push() + for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)): + self._server_sessions[i].next_session = self._server_sessions[i + 1] + def close(self, *exc_details): """Finish a given inference session, close the underlying connection""" if not self._closed: - self._server_inputs.clear() self._exit_server_sessions(self._server_sessions) self._server_sessions.clear() - self._chosen_spans.clear() self._closed = True def __exit__(self, *exc_details): diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 1a31d66..88d6d16 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -34,6 +34,7 @@ class SequenceManagerConfig: daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers + use_server_to_server: bool = True # Use direct server-to-server communication request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests update_period: float = 60 # refresh DHT information once in this many seconds diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 79376f8..65ee5c6 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -2,6 +2,9 @@ from __future__ import annotations import asyncio import contextlib +import multiprocessing.managers +import sys +from concurrent.futures import ThreadPoolExecutor from itertools import chain from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union @@ -11,6 +14,7 @@ from hivemind import ( DHT, MSGPackSerializer, P2PContext, + PeerID, deserialize_tensor_stream, deserialize_torch_tensor, nested_flatten, @@ -25,7 +29,7 @@ from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming import petals -from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID from petals.server.backend import TransformerBackend from petals.server.memory_cache import Handle from petals.server.task_pool import PrioritizedTaskPool @@ -34,6 +38,23 @@ from petals.utils.misc import DUMMY, is_dummy logger = get_logger(__name__) + +# Fix pickling protobufs, see https://stackoverflow.com/a/74873028 +sys.modules["runtime_pb2"] = runtime_pb2 + +# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256 + +_OriginalAutoProxy = multiprocessing.managers.AutoProxy + + +def patched_autoproxy(*args, manager_owned=True, **kwargs): + # Calling original AutoProxy without the unwanted key argument + return _OriginalAutoProxy(*args, **kwargs) + + +multiprocessing.managers.AutoProxy = patched_autoproxy + + CACHE_TOKENS_AVAILABLE = "cache_tokens_available" @@ -47,6 +68,9 @@ class TransformerConnectionHandler(ConnectionHandler): dht: DHT, module_backends: Dict[str, TransformerBackend], *, + dht_prefix: str, + push_manager: multiprocessing.managers.SyncManager, + session_queues: Dict[str, multiprocessing.managers.BaseProxy], # BaseProxy for queue.Queue inference_max_length: int, request_timeout: float, session_timeout: float, @@ -56,6 +80,11 @@ class TransformerConnectionHandler(ConnectionHandler): super().__init__(dht, module_backends) for module_backend in self.module_backends.values(): assert isinstance(module_backend, TransformerBackend) + self.dht_prefix = dht_prefix + self._push_manager = push_manager + self._session_queues = session_queues + self._executor = ThreadPoolExecutor(max_workers=float("inf")) # For waiting on self.session_queues + self.inference_max_length = inference_max_length self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout @@ -96,7 +125,7 @@ class TransformerConnectionHandler(ConnectionHandler): self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext, - ) -> AsyncIterator[runtime_pb2.ExpertRequest]: + ) -> AsyncIterator[runtime_pb2.ExpertResponse]: """Compute a single step of inference using attention cache; update attention cache accordingly.""" async with timeout(self.session_timeout): @@ -113,6 +142,7 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") points = metadata.get("points", 0) + session_id = metadata.get("session_id") if not requested_uids: raise ValueError("User must specify at least one block for inference, but got none") @@ -133,7 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler): async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles: assert len(cache_handles) == len(requested_backends) - while request.tensors: # iterate while user is willing to supply tensors + first_request = request + background_tasks = set() + async for request, metadata in self._iterate_inference_steps( + first_request, requests, session_id, requested_uids, context + ): hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) # Cast inputs to backend dtype @@ -141,7 +175,8 @@ class TransformerConnectionHandler(ConnectionHandler): assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" # parse deep prompts (optional argument) - if prompts is None or is_dummy(prompts): + has_prompts = prompts is not None and not is_dummy(prompts) + if not has_prompts: prompts = [None] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] @@ -180,25 +215,136 @@ class TransformerConnectionHandler(ConnectionHandler): ) # serialize and send last layer outputs - yield runtime_pb2.ExpertResponse( - tensors=[ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip( - (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) - ) - ] - ) + output_tensors = [ + serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) + for result, proto in zip( + (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) + ) + ] + if not has_prompts: + task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) + background_tasks.add(task) # Keep reference until it is done to save it from GC + task.add_done_callback(background_tasks.discard) + yield runtime_pb2.ExpertResponse(tensors=output_tensors) # prepare for next step - prefix_length += hidden_states.shape[1] - try: - request = await asyncio.wait_for(anext(requests), self.step_timeout) - except asyncio.TimeoutError: - self._log_request("rpc_inference.step", requested_uids, context, warning="timed out") - return + prefix_length += length_increment finally: self._log_request("rpc_inference.close", requested_uids, context) + async def _iterate_inference_steps( + self, + first_request: runtime_pb2.ExpertRequest, + requests: AsyncIterator[runtime_pb2.ExpertRequest], + session_id: Optional[str], + requested_uids: Sequence[str], + context: P2PContext, + ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]: + loop = asyncio.get_event_loop() + if session_id is not None: + push_queue = self._push_manager.Queue() + self._session_queues[session_id] = push_queue + + processed_step_ids = set() + n_pushes = n_late_pushes = 0 + request = first_request + anext_task = get_push_task = None + try: + while request.tensors: # iterate while user is willing to supply tensors + metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + step_id = metadata.get("step_id") + + pushed = metadata.get("pushed") + if pushed: + n_pushes += 1 + + if step_id is None or step_id not in processed_step_ids: + yield request, metadata + if step_id is not None: + processed_step_ids.add(step_id) + elif pushed: + n_late_pushes += 1 + self._log_request( + "rpc_inference.push", + requested_uids, + context, + warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time", + ) + + # Wait for the next request, coming either from the `requests` iterator or `push_queue` + if anext_task is None: + anext_task = asyncio.create_task(anext(requests)) + if get_push_task is None: + if session_id is not None: + get_push_task = loop.run_in_executor(self._executor, push_queue.get) + else: + get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task + done, _ = await asyncio.wait( + [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED + ) + + if anext_task in done: + request = await anext_task + anext_task = None + elif get_push_task in done: + request = await get_push_task + get_push_task = None + else: + self._log_request("rpc_inference.step", requested_uids, context, warning="timed out") + anext_task.cancel() + get_push_task.cancel() + return + except: + logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True) + raise + finally: + if session_id is not None: + push_queue.put(None) # Stop thread for get_push_task + del self._session_queues[session_id] + + async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: + """Directly push activation tensors from one server to another""" + + requested_uids = self._check_uids(request.uid) + self._log_request("rpc_push", requested_uids, context) + + metadata = MSGPackSerializer.loads(request.metadata) + session_id = metadata["session_id"] + self._session_queues[session_id].put(request) + return runtime_pb2.ExpertResponse() + + async def _push_outputs( + self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict + ) -> None: + try: + next_servers = metadata.get("next_servers") + if not next_servers: + return + + next_peer_id, next_session_id, next_start, next_end = next_servers[0] + next_peer_id = PeerID.from_base58(next_peer_id) + next_uid = CHAIN_DELIMITER.join(f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end)) + + # Sending hidden states serialized with output_schema to avoid double serialization + next_tensors = [serialized_outputs] + request.tensors[1:] + next_metadata = metadata.copy() + next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True) + + stub = self.get_stub(self._p2p, next_peer_id) + await stub.rpc_push( + runtime_pb2.ExpertRequest( + uid=next_uid, + tensors=next_tensors, + metadata=MSGPackSerializer.dumps(next_metadata), + ), + timeout=self.request_timeout, + ) + except Exception: + logger.debug( + f"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:", + exc_info=True, + ) + async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: async with timeout(self.request_timeout): # Parse request and prepare backends @@ -348,7 +494,7 @@ class TransformerConnectionHandler(ConnectionHandler): @contextlib.asynccontextmanager async def _allocate_cache( self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int - ) -> Sequence[Sequence[Handle, ...]]: + ) -> Sequence[Sequence[Handle]]: """ Allocate memory cache for all transformer blocks, return cache handle :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend @@ -358,7 +504,13 @@ class TransformerConnectionHandler(ConnectionHandler): yield nested_pack(handles, descriptors) def _log_request( - self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None + self, + method: str, + uids: Optional[Sequence[ModuleUID]], + context: P2PContext, + *, + debug: Optional[str] = None, + warning: Optional[str] = None, ) -> None: if uids is not None: friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid] @@ -370,10 +522,12 @@ class TransformerConnectionHandler(ConnectionHandler): friendly_remote_id = "..." + str(context.remote_id)[-6:] message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})" - if warning is None: - logger.info(message) - else: + if warning is not None: logger.warning(f"{message}: {warning}") + elif debug is not None: + logger.debug(f"{message}: {debug}") + else: + logger.info(message) async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo: """Return metadata about stored block uids and current load""" diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 2fbaad2..894e9ea 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -45,7 +45,7 @@ class Server: self, *, initial_peers: List[str], - prefix: Optional[str], + dht_prefix: Optional[str], converted_model_name_or_path: str, throughput: Union[float, str], num_blocks: Optional[int] = None, @@ -105,13 +105,13 @@ class Server: revision=revision, ) - if prefix is None: - prefix = self.block_config.dht_prefix - assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, ( + if dht_prefix is None: + dht_prefix = self.block_config.dht_prefix + assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, ( f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. " - f"Please specify another --prefix manually when starting a server" + f"Please specify another --dht_prefix manually when starting a server" ) - self.prefix = prefix + self.dht_prefix = dht_prefix if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) @@ -121,7 +121,8 @@ class Server: self.session_timeout, self.step_timeout = session_timeout, step_timeout self.module_uids = [ - f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers) + f"{self.dht_prefix}{UID_DELIMITER}{block_index}" + for block_index in range(self.block_config.num_hidden_layers) ] if dht_client_mode is None: @@ -258,7 +259,7 @@ class Server: block_indices = self._choose_blocks() self.module_container = ModuleContainer.create( dht=self.dht, - prefix=self.prefix, + dht_prefix=self.dht_prefix, converted_model_name_or_path=self.converted_model_name_or_path, block_config=self.block_config, attn_cache_bytes=self.attn_cache_bytes, @@ -359,7 +360,7 @@ class ModuleContainer(threading.Thread): cls, *, dht: DHT, - prefix: str, + dht_prefix: str, converted_model_name_or_path: str, block_config: PretrainedConfig, attn_cache_bytes: int, @@ -382,7 +383,7 @@ class ModuleContainer(threading.Thread): should_validate_reachability: bool, **kwargs, ) -> ModuleContainer: - module_uids = [f"{prefix}.{block_index}" for block_index in block_indices] + module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] joining_announcer = ModuleAnnouncerThread( module_uids, dht, @@ -459,6 +460,7 @@ class ModuleContainer(threading.Thread): return cls( dht, + dht_prefix, blocks, throughput=throughput, update_period=update_period, @@ -469,6 +471,7 @@ class ModuleContainer(threading.Thread): def __init__( self, dht: DHT, + dht_prefix: str, module_backends: Dict[str, TransformerBackend], *, inference_max_length: int, @@ -486,10 +489,17 @@ class ModuleContainer(threading.Thread): self.dht, self.module_backends = dht, module_backends self.throughput, self.update_period, self.expiration = throughput, update_period, expiration + + self.push_manager = mp.Manager() + self.push_manager.__enter__() + session_queues = self.push_manager.dict() self.conn_handlers = [ TransformerConnectionHandler( dht, self.module_backends, + dht_prefix=dht_prefix, + push_manager=self.push_manager, + session_queues=session_queues, inference_max_length=inference_max_length, request_timeout=request_timeout, session_timeout=session_timeout, @@ -497,6 +507,7 @@ class ModuleContainer(threading.Thread): ) for _ in range(num_handlers) ] + self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs) # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed. self.online_announcer = ModuleAnnouncerThread( @@ -577,6 +588,7 @@ class ModuleContainer(threading.Thread): logger.debug("Shutting down connection handlers") for handler in self.conn_handlers: handler.shutdown() + self.push_manager.__exit__(None, None, None) logger.debug(f"Shutting down pools") for pool in self.runtime.pools: From fa095f6461c50f600f950ba18deaf633d804c68e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 11 Jul 2023 18:53:17 +0400 Subject: [PATCH 100/168] Use 4-bit for llama by default, use bitsandbytes 0.40.0.post3 (#340) NF4 inference with bitsandbytes 0.40.0.post3 is ~2x faster than int8 inference, though training is still ~3x slower, see: - [bitsandbytes 0.40.0 Release notes](https://github.com/TimDettmers/bitsandbytes/releases/tag/0.40.0) - [RPS benchmarks](https://github.com/bigscience-workshop/petals/pull/333#issuecomment-1614040385) We've decided to use NF4 by default for LLaMA. --- setup.cfg | 2 +- src/petals/server/server.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index fb1fa23..76185eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ packages = find: python_requires = >=3.7 install_requires = torch>=1.12 - bitsandbytes==0.39.1 + bitsandbytes==0.40.0.post3 accelerate>=0.16.0,<1.0.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 894e9ea..eddb76e 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -166,7 +166,10 @@ class Server: check_device_balance(self.tensor_parallel_devices) if quant_type is None: - quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE + if device.type == "cuda": + quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8 + else: + quant_type = QuantType.NONE self.quant_type = quant_type logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format") From b28f5016ea0a8d90a4b82c56c213a8e3e350b59e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 11 Jul 2023 21:42:35 +0400 Subject: [PATCH 101/168] Delete deprecated petals.cli scripts (#336) --- src/petals/cli/deploy_server.sh | 79 ------------- src/petals/cli/inference_one_block.py | 51 -------- .../cli/local_server_config_example.cfg | 5 - .../cli/remote_server_config_example.cfg | 6 - src/petals/cli/run_local_servers.sh | 109 ----------------- src/petals/cli/run_remote_servers.sh | 110 ------------------ 6 files changed, 360 deletions(-) delete mode 100644 src/petals/cli/deploy_server.sh delete mode 100644 src/petals/cli/inference_one_block.py delete mode 100644 src/petals/cli/local_server_config_example.cfg delete mode 100644 src/petals/cli/remote_server_config_example.cfg delete mode 100644 src/petals/cli/run_local_servers.sh delete mode 100644 src/petals/cli/run_remote_servers.sh diff --git a/src/petals/cli/deploy_server.sh b/src/petals/cli/deploy_server.sh deleted file mode 100644 index 0bea785..0000000 --- a/src/petals/cli/deploy_server.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env bash - -################# -# Parse options # -################# - -instructions() { - echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2 - echo " -m: model name" - echo " -i: initial peer" - echo " -d: device" >&2 - echo " -p: server identity path" >&2 - echo " -b: block_ids" >&2 - echo " -a: host maddrs" >&2 - echo " -t: whether to run local tests" >&2 - exit 1 -} - -if [ ! $# -ge 8 ]; then - instructions -fi - -while getopts ":m:i:d:p:b:a:t:" option; do - case $option in - m) MODEL_NAME=${OPTARG} - ;; - i) INITIAL_PEER=${OPTARG} - ;; - d) DEVICE=${OPTARG} - ;; - p) SERVER_ID_PATH=${OPTARG} - ;; - b) BLOCK_IDS=${OPTARG} - ;; - a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs - ;; - t) RUN_LOCAL_TESTS=true - ;; - \?) instructions - ;; - esac -done - - -echo "==========" -echo "= Config =" -echo "==========" -echo "Model name: ${MODEL_NAME}" -echo "Initial peer: ${INITIAL_PEER}" -echo "Device: ${DEVICE}" -echo "Server name: ${SERVER_ID_PATH}" -echo "Server address: ${HOST_MADDR}" -echo "Bloom blocks: ${BLOCK_IDS}" - - -########################### -# Install or activate env # -########################### - -# TODO fix bug with self calling -source ~/miniconda3/etc/profile.d/conda.sh -if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then - conda activate bloom-demo -else - conda create -y --name bloom-demo python=3.8.12 pip - conda activate bloom-demo - - conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 - pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install -i https://pypi.org/simple -r . - pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113 -fi - -############## -# Run server # -############## - -python -m petals.cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \ - --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py deleted file mode 100644 index 6d53e9b..0000000 --- a/src/petals/cli/inference_one_block.py +++ /dev/null @@ -1,51 +0,0 @@ -import argparse - -import torch -from hivemind.utils.logging import get_logger -from tqdm.auto import trange -from transformers import BloomConfig -from transformers.models.bloom.modeling_bloom import build_alibi_tensor - -from petals.models.bloom.block import BloomBlock - -logger = get_logger(__name__) - -logger.warning("inference_one_block will soon be deprecated in favour of tests!") - - -def print_device_info(device=None): - """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528""" - device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) - logger.info(f"Using device: {device}") - - # Additional Info when using cuda - if device.type == "cuda": - logger.info(torch.cuda.get_device_name(0)) - logger.info(f"Memory Usage:") - logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB") - logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data") - parser.add_argument("--config", required=True, type=str, help="Path to a config json file") - parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict") - parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run") - parser.add_argument("--device", default=None, type=str, help="Run inference on this device") - args = parser.parse_args() - - if args.device is None: - args.device = "cuda" if torch.cuda.is_available() else "cpu" - - config = BloomConfig.from_json_file(args.config) - block = BloomBlock(config).to(args.device) - - cache = None - - for i in trange(args.num_steps): - dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device) - alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device) - with torch.no_grad(): - outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache) - - print_device_info(args.device) diff --git a/src/petals/cli/local_server_config_example.cfg b/src/petals/cli/local_server_config_example.cfg deleted file mode 100644 index 8cbfe45..0000000 --- a/src/petals/cli/local_server_config_example.cfg +++ /dev/null @@ -1,5 +0,0 @@ -device=cpu -block_ids=2:3 -id_path=./server.id -maddr=/ip4/127.0.0.1/tcp/30000 -# diff --git a/src/petals/cli/remote_server_config_example.cfg b/src/petals/cli/remote_server_config_example.cfg deleted file mode 100644 index 54df7af..0000000 --- a/src/petals/cli/remote_server_config_example.cfg +++ /dev/null @@ -1,6 +0,0 @@ -name=bloom-peer-0.bloom.net -device=cpu -block_ids=1:3 -id_path=./server.id -maddr=/ip4/0.0.0.0/tcp/30000 -# \ No newline at end of file diff --git a/src/petals/cli/run_local_servers.sh b/src/petals/cli/run_local_servers.sh deleted file mode 100644 index 0e449cb..0000000 --- a/src/petals/cli/run_local_servers.sh +++ /dev/null @@ -1,109 +0,0 @@ -# !/usr/bin/env bash - -################# -# Parse options # -################# - -instructions() { - echo "Usage: $0 [-n] [-c]" >&2 - echo " -n: number of servers to run" >&2 - echo " -c: path to the server configs" >&2 - exit 1 -} - -if [ $# != 4 ]; then - instructions -fi - -while getopts ":n:c:t:" option; do - case $option in - n) NUM_SERVERS=${OPTARG} - ;; - c) CONFIG_PATH=${OPTARG} - ;; - \?) instructions - ;; - esac -done - - -########################### -# Install or activate env # -########################### - -source ~/miniconda3/etc/profile.d/conda.sh -if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then - conda activate bloom-demo -else - conda create -y --name bloom-demo python=3.8.12 pip - conda activate bloom-demo - - conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 - pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install -i https://pypi.org/simple -r . - pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113 -fi - - -####################### -# Create Initial peer # -####################### - -hivemind-dht &> tmp.out & -sleep 5 -INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" ) -echo "Initial peer: ${INITIAL_PEER}" - - -############################## -# Initialize the config file # -############################## - -typeset -A cfg -cfg=( # set default values in config array - [device]="cpu" - [block_ids]="1:2" - [id_path]="server.id" - [maddr]="/ip4/127.0.0.1/tcp/30000" -) - -############### -# Run servers # -############### - -for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) ) -do - ############### - # Read config # - ############### - - while read line - do - if echo $line | grep -F = &>/dev/null - then - varname=$(echo "$line" | cut -d '=' -f 1) - cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-) - fi - done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg - - echo "=== Server #${SERVER_ID} ===" - echo "Server ID: ${cfg[id_path]}" - echo "Device: ${cfg[device]}" - echo "Bloom block ids: ${cfg[block_ids]}" - echo "Host maddr: ${cfg[maddr]}" - echo "" - - ############## - # Run server # - ############## - - tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]} -done - -##################### -# Kill initial peer # -##################### - -sleep 10 -pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht -rm tmp.out \ No newline at end of file diff --git a/src/petals/cli/run_remote_servers.sh b/src/petals/cli/run_remote_servers.sh deleted file mode 100644 index e3f30de..0000000 --- a/src/petals/cli/run_remote_servers.sh +++ /dev/null @@ -1,110 +0,0 @@ -# !/usr/bin/env bash - -SSH_KEY_PATH="~/.ssh/" - -################# -# Parse options # -################# - -instructions() { - echo "Usage: $0 [-u] [-n] [-c]" >&2 - echo " -u: username" >&2 - echo " -n: number of servers to run" >&2 - echo " -c: path to the server configs" >&2 - exit 1 -} - -if [ $# != 6 ]; then - instructions -fi - -while getopts ":u:n:c:" option; do - case $option in - u) USERNAME=${OPTARG} - ;; - n) NUM_SERVERS=${OPTARG} - ;; - c) CONFIG_PATH=${OPTARG} - ;; - \?) instructions - ;; - esac -done - - -########################### -# Install or activate env # -########################### - -source ~/miniconda3/etc/profile.d/conda.sh -if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then - conda activate bloom-demo -else - conda create -y --name bloom-demo python=3.8.12 pip - conda activate bloom-demo - - conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 - pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install -i https://pypi.org/simple -r . -fi - - -####################### -# Create Initial peer # -####################### - -hivemind-dht &> tmp.out & - -sleep 5 -INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" ) -rm tmp.out -echo "Initial peer: ${INITIAL_PEER}" - - -############################## -# Initialize the config file # -############################## - -typeset -A cfg -cfg=( # set default values in config array - [name]="" - [device]="cpu" - [block_ids]="1:2" - [id_path]="server.id" - [maddr]="/ip4/0.0.0.0/tcp/30000" -) - -############### -# Run servers # -############### - -for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) ) -do - ############### - # Read config # - ############### - - while read line - do - if echo $line | grep -F = &>/dev/null - then - varname=$(echo "$line" | cut -d '=' -f 1) - cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-) - fi - done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg - - SERVER_NAME="${USERNAME}@${cfg[name]}" - echo "=== Server #${SERVER_ID} ===" - echo "Server name ${SERVER_NAME}" - echo "Server ID: ${cfg[id_path]}" - echo "Device: ${cfg[device]}" - echo "Bloom block ids: ${cfg[block_ids]}" - echo "Host maddr: ${cfg[maddr]}" - echo "=================" - - ############## - # Run server # - ############## - - ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'" -done \ No newline at end of file From dfc6578c8e406e0a05c7f87b0fc45a48cdd83584 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 12 Jul 2023 15:29:59 +0400 Subject: [PATCH 102/168] Use bitsandbytes 0.40.0.post4 with bias hotfix (#342) This PR includes a bnb hotfix: https://github.com/TimDettmers/bitsandbytes/commit/90b0ac57b0d8d8f996126deb8bba6b7dc75b4327 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 76185eb..6242651 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ packages = find: python_requires = >=3.7 install_requires = torch>=1.12 - bitsandbytes==0.40.0.post3 + bitsandbytes==0.40.0.post4 accelerate>=0.16.0,<1.0.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 From b9f0a5467fc67fe6e93d2901484dd5f36d60a316 Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Wed, 12 Jul 2023 16:22:28 +0400 Subject: [PATCH 103/168] Support peft LoRA adapters (#335) Implement an option to deploy PEFT adapters to a server. Clients can set active_adapter=... to use these adapters. --------- Co-authored-by: Aleksandr Borzunov Co-authored-by: justheuristic --- .github/workflows/run-tests.yaml | 11 +- setup.cfg | 2 + src/petals/cli/run_server.py | 2 + src/petals/client/remote_sequential.py | 3 +- src/petals/client/routing/sequence_manager.py | 10 +- src/petals/data_structures.py | 3 +- src/petals/dht_utils.py | 30 ++- src/petals/server/backend.py | 27 ++- src/petals/server/handler.py | 38 +++- src/petals/server/server.py | 27 ++- src/petals/server/throughput.py | 2 +- src/petals/utils/convert_block.py | 25 ++- src/petals/utils/misc.py | 9 + src/petals/utils/peft.py | 208 ++++++++++++++++++ tests/test_full_model.py | 13 +- tests/test_peft.py | 66 ++++++ tests/test_utils.py | 2 + 17 files changed, 444 insertions(+), 34 deletions(-) create mode 100644 src/petals/utils/peft.py create mode 100644 tests/test_peft.py diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index fbb5b72..b98667e 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -33,10 +33,11 @@ jobs: run: | export MODEL_NAME=bigscience/bloom-560m export REF_NAME=bigscience/bloom-560m + export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \ - --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log & + --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log & SERVER1_PID=$! sleep 5 # wait for the first server to initialize DHT @@ -45,17 +46,17 @@ jobs: # ^-- server 1 multiaddr is determined by --identity and --host_maddrs python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log & + --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log & SERVER2_PID=$! sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log & + python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \ + --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log & SERVER3_PID=$! python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log & + --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log & SERVER4_PID=$! tail -n 100 -f server*.log & diff --git a/setup.cfg b/setup.cfg index 6242651..f56a7cc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,8 @@ install_requires = cpufeature>=0.2.0 packaging>=20.9 sentencepiece>=0.1.99 + peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735 + safetensors>=0.3.1 [options.extras_require] dev = diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 1d3c438..6b3fde8 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -146,6 +146,8 @@ def main(): help="Skip checking this server's reachability via health.petals.ml " "when connecting to the public swarm. If you connect to a private swarm, " "the check is skipped by default. Use this option only if you know what you are doing") + + parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.") # fmt:on args = vars(parser.parse_args()) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 745b5c1..6ae664a 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -28,6 +28,7 @@ class RemoteSequential(nn.Module): dht: Optional[DHT] = None, start_block: Optional[int] = None, end_block: Optional[int] = None, + **kwargs, ): super().__init__() self.config = config @@ -41,7 +42,7 @@ class RemoteSequential(nn.Module): if end_block is None: end_block = self.config.num_hidden_layers block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block)) - sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht) + sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs) self.sequence_manager = sequence_manager def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY): diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 88d6d16..fc505cc 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -43,6 +43,7 @@ class SequenceManagerConfig: min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) max_backoff: float = 60 # limit maximal sleep time between retries to this value ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds + active_adapter: Optional[str] = None @dataclasses.dataclass @@ -78,6 +79,7 @@ class RemoteSequenceManager: *, dht: Optional[DHT] = None, state: Optional[SequenceManagerState] = None, + active_adapter: Optional[str] = None, ): assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." @@ -115,7 +117,9 @@ class RemoteSequenceManager: if state.sequence_info.last_updated_time is None: # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records # in the first _update() instead of the latest ones. This makes the first .update() faster. - petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True) + petals.dht_utils.get_remote_module_infos( + self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True + ) self._need_latest_infos = False else: assert block_uids == state.sequence_info.block_uids @@ -179,7 +183,7 @@ class RemoteSequenceManager: def _update(self): """Perform an immediate and synchronous refresh, may take time""" new_block_infos = petals.dht_utils.get_remote_module_infos( - self.dht, self.block_uids, latest=self._need_latest_infos + self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos ) self._need_latest_infos = True # All future _update() should use latest infos @@ -307,7 +311,7 @@ class RemoteSequenceManager: :param kwargs: additional request context, such as remote peer ID :returns: msgpack-serialized metadata dict that will be passed alongside a given request """ - return dict(points=self.policy.get_points(protocol, *args, **kwargs)) + return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter) def shutdown(self): self._thread.shutdown() diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 80b8f62..254faae 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple from hivemind import PeerID from hivemind.moe.expert_uid import ExpertUID @@ -57,3 +57,4 @@ class InferenceMetadata: uid: ExpertUID prefix_length: int cache_handles: Tuple[Handle, ...] + active_adapter: Optional[str] diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index 177b2f6..99316f2 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -22,6 +22,7 @@ def declare_active_modules( expiration_time: DHTExpiration, state: ServerState, throughput: float, + adapters: Optional[Sequence[str]] = None, wait: bool = True, ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]: """ @@ -39,6 +40,7 @@ def declare_active_modules( uids = list(uids) for uid in uids: assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid + return dht.run_coroutine( partial( _declare_active_modules, @@ -46,6 +48,7 @@ def declare_active_modules( expiration_time=expiration_time, state=state, throughput=throughput, + adapters=list(adapters or []), ), return_future=not wait, ) @@ -58,12 +61,13 @@ async def _declare_active_modules( expiration_time: DHTExpiration, state: ServerState, throughput: float, + adapters: List[str], ) -> Dict[ModuleUID, bool]: num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) return await node.store_many( keys=uids, subkeys=[dht.peer_id.to_base58()] * len(uids), - values=[(state.value, throughput)] * len(uids), + values=[(state.value, throughput, dict(adapters=adapters))] * len(uids), expiration_time=expiration_time, num_workers=num_workers, ) @@ -73,18 +77,30 @@ def get_remote_module_infos( dht: DHT, uids: Sequence[ModuleUID], expiration_time: Optional[DHTExpiration] = None, + active_adapter: Optional[str] = None, *, latest: bool = False, return_future: bool = False, ) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]: return dht.run_coroutine( - partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest), + partial( + _get_remote_module_infos, + uids=uids, + active_adapter=active_adapter, + expiration_time=expiration_time, + latest=latest, + ), return_future=return_future, ) async def _get_remote_module_infos( - dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool + dht: DHT, + node: DHTNode, + uids: List[ModuleUID], + active_adapter: Optional[str], + expiration_time: Optional[DHTExpiration], + latest: bool, ) -> List[Optional[RemoteModuleInfo]]: if latest: assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both" @@ -105,7 +121,13 @@ async def _get_remote_module_infos( for peer_id, server_info in metadata.value.items(): try: peer_id = PeerID.from_base58(peer_id) - state, throughput = server_info.value + state, throughput = server_info.value[:2] + extra_info = server_info.value[2] if len(server_info.value) > 2 else {} + adapters = extra_info.get("adapters", []) + if bool(active_adapter) and active_adapter not in adapters: + logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}") + continue + if not ( isinstance(state, int) and isinstance(throughput, float) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index adcd617..9e81170 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -2,8 +2,9 @@ from __future__ import annotations from collections import Counter from itertools import chain -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple, Union +import peft import torch from hivemind import BatchTensorDescriptor, TensorDescriptor from hivemind.moe.expert_uid import ExpertUID @@ -80,6 +81,18 @@ class TransformerBackend(ModuleBackend): cache_tensors.extend((keys, values)) return cache_tensors + def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: + *inputs, active_adapter = inputs + if not self.load_adapter_(active_adapter): + raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded") + return super().forward(*inputs) + + def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: + *inputs, active_adapter = inputs + if not self.load_adapter_(active_adapter): + raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded") + return super().backward(*inputs) + @torch.inference_mode() def inference_step( self, @@ -88,6 +101,8 @@ class TransformerBackend(ModuleBackend): inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" + if not self.load_adapter_(inference_info.active_adapter): + raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded") with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors: self._reorder_cache_inplace(cache_tensors, hypo_ids) layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) @@ -139,6 +154,16 @@ class TransformerBackend(ModuleBackend): for p in self.module.parameters(): p.data = dummy + def load_adapter_(self, active_adapter: Optional[str] = None) -> bool: + """Activate a given adapter set if available. Return True if available (or no adapter), False if missing""" + adapter_was_loaded = False + for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter + if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)): + layer.active_adapter = active_adapter # empty string for no adapter + if active_adapter in layer.lora_A.keys(): + adapter_was_loaded = True + return adapter_was_loaded or not active_adapter + def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]): """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call""" diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 65ee5c6..d7295ca 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -141,6 +141,7 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) session_id = metadata.get("session_id") @@ -201,7 +202,7 @@ class TransformerConnectionHandler(ConnectionHandler): ) inference_infos = tuple( - InferenceMetadata(uid, prefix_length, tuple(handles)) + InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) for uid, handles in zip(requested_uids, cache_handles) ) @@ -354,13 +355,18 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" hidden_states = await _rpc_forward( - *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + *flat_inputs, + requested_backends=requested_backends, + prioritizer=self._prioritizer, + active_adapter=active_adapter, + points=points, ) return runtime_pb2.ExpertResponse( tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) @@ -376,13 +382,18 @@ class TransformerConnectionHandler(ConnectionHandler): self._log_request("rpc_forward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_forward_stream should have number of points as number or None, got {points}" hidden_states = await _rpc_forward( - *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + *flat_inputs, + requested_backends=requested_backends, + prioritizer=self._prioritizer, + active_adapter=active_adapter, + points=points, ) # Split the serialized_output for streaming and respond to client @@ -422,13 +433,18 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" grads = await _rpc_backward( - *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + *flat_tensors, + requested_backends=requested_backends, + prioritizer=self._prioritizer, + active_adapter=active_adapter, + points=points, ) return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata)) @@ -442,13 +458,18 @@ class TransformerConnectionHandler(ConnectionHandler): self._log_request("rpc_backward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_backward_stream should have number of points as number or None, got {points}" grads = await _rpc_backward( - *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + *flat_tensors, + requested_backends=requested_backends, + prioritizer=self._prioritizer, + active_adapter=active_adapter, + points=points, ) # Split the serialized_grad_inputs for streaming and respond for tensor in self._serialize_grads(grads, requested_backends, metadata): @@ -553,6 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler): async def _rpc_forward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, ) -> torch.Tensor: @@ -585,6 +607,7 @@ async def _rpc_forward( ) (hidden_states,) = await backend.forward_pool.submit_task( hidden_states, + active_adapter, priority=priority, ) assert isinstance(hidden_states, torch.Tensor) @@ -598,6 +621,7 @@ async def _rpc_forward( async def _rpc_backward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: @@ -623,7 +647,7 @@ async def _rpc_backward( priority = prioritizer.prioritize( inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" ) - (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority) + (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) assert isinstance(inputs, torch.Tensor) @@ -639,7 +663,7 @@ async def _rpc_backward( priority = prioritizer.prioritize( inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) - (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority) + (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): diff --git a/src/petals/server/server.py b/src/petals/server/server.py index eddb76e..643bf1b 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -81,6 +81,7 @@ class Server: dht_client_mode: Optional[bool] = None, use_relay: bool = True, use_auto_relay: bool = True, + adapters: Optional[List[str]] = None, **kwargs, ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" @@ -218,6 +219,8 @@ class Server: self.mean_balance_check_period = mean_balance_check_period self.mean_block_selection_delay = mean_block_selection_delay + self.adapters = adapters + self.stop = threading.Event() def _choose_num_blocks(self) -> int: @@ -291,6 +294,7 @@ class Server: quant_type=self.quant_type, tensor_parallel_devices=self.tensor_parallel_devices, should_validate_reachability=self.should_validate_reachability, + adapters=self.adapters, start=True, ) try: @@ -384,6 +388,7 @@ class ModuleContainer(threading.Thread): quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], should_validate_reachability: bool, + adapters: Optional[List[str]] = None, **kwargs, ) -> ModuleContainer: module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] @@ -391,6 +396,7 @@ class ModuleContainer(threading.Thread): module_uids, dht, ServerState.JOINING, + adapters=adapters, throughput=throughput, update_period=update_period, expiration=expiration, @@ -415,7 +421,19 @@ class ModuleContainer(threading.Thread): cache_dir=cache_dir, max_disk_space=max_disk_space, ) - block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True) + block = convert_block( + block, + block_index, + block_config, + tensor_parallel_devices, + device, + quant_type, + adapters=adapters, + freeze=True, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + max_disk_space=max_disk_space, + ) blocks[module_uid] = TransformerBackend( module_uid, block, @@ -452,6 +470,7 @@ class ModuleContainer(threading.Thread): expiration_time=get_dht_time() + expiration, state=ServerState.OFFLINE, throughput=throughput, + adapters=adapters, ) logger.info(f"Announced that blocks {module_uids} are offline") raise @@ -465,6 +484,7 @@ class ModuleContainer(threading.Thread): dht, dht_prefix, blocks, + adapters=adapters, throughput=throughput, update_period=update_period, expiration=expiration, @@ -480,6 +500,7 @@ class ModuleContainer(threading.Thread): inference_max_length: int, num_handlers: int, throughput: float, + adapters: Optional[Sequence[str]], update_period: float, expiration: Optional[float] = None, request_timeout: float, @@ -517,6 +538,7 @@ class ModuleContainer(threading.Thread): list(self.module_backends.keys()), dht, ServerState.ONLINE, + adapters=adapters, throughput=throughput, update_period=update_period, expiration=expiration, @@ -616,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread): module_uids: List[str], dht: DHT, state: ServerState, + adapters: Optional[Sequence[str]], *, throughput: float, update_period: float = 30, @@ -626,6 +649,7 @@ class ModuleAnnouncerThread(threading.Thread): self.module_uids = module_uids self.dht = dht self.state = state + self.adapters = adapters self.throughput = throughput self.update_period = update_period self.expiration = expiration @@ -639,6 +663,7 @@ class ModuleAnnouncerThread(threading.Thread): expiration_time=get_dht_time() + self.expiration, state=self.state, throughput=self.throughput, + adapters=self.adapters, ) if self.stop.wait(self.update_period): break diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 76bbc85..20625e6 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -172,7 +172,7 @@ def measure_compute_rps( tensor_parallel_devices = (device,) with torch.inference_mode(): block = config.block_class(config).to(dtype) - block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) cache = None elapsed = 0 diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 6b129f5..b1c412e 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -3,8 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor par """ import os import re -from enum import Enum -from typing import Sequence +from typing import List, Optional, Sequence import tensor_parallel as tp import torch @@ -13,23 +12,23 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler from tensor_parallel.slicing_configs import get_bloom_config from transformers import PretrainedConfig +from petals.utils.misc import QuantType +from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft + use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) -class QuantType(Enum): - NONE = 0 - INT8 = 1 # 8-bit as in the LLM.int8() paper - NF4 = 2 # 4-bit as in the QLoRA paper - - def convert_block( block: nn.Module, + block_index: int, config: PretrainedConfig, tensor_parallel_devices: Sequence[torch.device], output_device: torch.device, quant_type: QuantType, freeze: bool = True, + adapters: Optional[List[str]] = None, + **kwargs, ) -> tp.TensorParallel: """ Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization @@ -56,6 +55,16 @@ def convert_block( for shard, device in zip(block.module_shards, block.devices): shard.to(device) + if adapters: + create_lora_adapter(block, quant_type=quant_type) + for adapter_name in adapters: + adapter_config, adapter_state_dict = load_peft( + adapter_name, + block_idx=block_index, + **kwargs, + ) + add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict) + return block diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py index 2f67202..99b246c 100644 --- a/src/petals/utils/misc.py +++ b/src/petals/utils/misc.py @@ -1,5 +1,14 @@ +from enum import Enum + import torch + +class QuantType(Enum): + NONE = 0 + INT8 = 1 # 8-bit as in the LLM.int8() paper + NF4 = 2 # 4-bit as in the QLoRA paper + + DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py new file mode 100644 index 0000000..c551f97 --- /dev/null +++ b/src/petals/utils/peft.py @@ -0,0 +1,208 @@ +import re +import time +from typing import List, Optional + +import bitsandbytes as bnb +import torch.nn as nn +from hivemind.utils.logging import get_logger +from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url +from peft.tuners import lora +from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig +from safetensors import safe_open +from safetensors.torch import load_file +from transformers.utils import get_file_from_repo + +from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for +from petals.utils.misc import QuantType + +logger = get_logger(__name__) + + +def check_peft_repository(repo_id: str) -> bool: + fs = HfFileSystem() + list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False) + return len(list_of_files) > 0 + + +def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None): + tensors = dict() + is_tensors_found = dict() + common_layer_patter_re = ( + ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+" + ) + with safe_open(filepath, framework=framework, device=device) as f: + for k in f.keys(): + if re.match(common_layer_patter_re, k): + is_tensors_found[block_idx] = True + tensors[k] = f.get_tensor(k) + if not is_tensors_found.get(block_idx, False): + logger.warning(f"There is no peft weights for block {block_idx}") + return tensors + + +def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs): + config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs) + if config_path is None: + raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}") + config = PeftConfig.from_json_file(config_path) + + weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs) + if weight_path is None: + raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}") + if block_idx is None: + return config, load_file(weight_path) + return config, load_specific_module(block_idx, weight_path, device=device) + + +def load_peft( + repo_id: str, + block_idx: Optional[int] = None, + device: Optional[int] = None, + *, + revision: Optional[str] = None, + use_auth_token: Optional[str] = None, + cache_dir: str, + max_disk_space: Optional[int] = None, + delay: float = 30, +): + # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here + + if not check_peft_repository(repo_id): + raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.") + + try: + with allow_cache_reads(cache_dir): + return get_adapter_from_repo( + repo_id, + block_idx, + device, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + local_files_only=False, + ) + except Exception: + logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True) + + while True: + try: + with allow_cache_writes(cache_dir): + config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision) + config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size + weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision) + weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size + + file_size = config_file_size + weight_file_size + if file_size is not None: + free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) + else: + logger.warning(f"Failed to fetch size from peft repo {repo_id}") + + return get_adapter_from_repo( + repo_id, + block_idx, + device, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + local_files_only=False, + ) + except Exception as e: + logger.warning( + f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True + ) + time.sleep(delay) + + +def create_lora_adapter(block, quant_type: QuantType): + for name, module in block.named_modules(): + for child_name, child in module.named_children(): + lora_wrapped_child = None + if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)): + continue + if quant_type == QuantType.INT8: + kwargs = { + "has_fp16_weights": False, + "threshold": 6.0, + "bias": hasattr(child, "bias") and child.bias is not None, + } + lora_wrapped_child = lora.Linear8bitLt( + child_name, + child.in_features, + child.out_features, + **kwargs, + ) + elif quant_type == QuantType.NF4: + kwargs = { + "compress_statistics": True, + "quant_type": "nf4", + "blocksize": 64, + "bias": hasattr(child, "bias") and child.bias is not None, + } + lora_wrapped_child = lora.Linear4bit( + child_name, + child.in_features, + child.out_features, + **kwargs, + ) + else: + bias = hasattr(child, "bias") and child.bias is not None + lora_wrapped_child = lora.Linear( + child_name, + child.in_features, + child.out_features, + bias=bias, + ) + if lora_wrapped_child: + lora_wrapped_child.active_adapter = None + lora_wrapped_child.weight = child.weight + lora_wrapped_child.bias = child.bias + for p in lora_wrapped_child.parameters(): + p.requires_grad = False + setattr(module, child_name, lora_wrapped_child) + + +def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict): + assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters" + for name, module in block.named_modules(): + for child_name, child in module.named_children(): + if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)): + continue + + if child_name in peft_config["target_modules"] or ( + isinstance(peft_config["target_modules"], str) + and re.fullmatch(peft_config["target_modules"], child_name) + ): + is_lora_a_loaded = False + is_lora_b_loaded = False + for peft_key in peft_state_dict: + if peft_key.find(child_name) == -1: + continue + + if adapter_name not in child.lora_A: + child.update_layer( + adapter_name, + peft_config["r"], + peft_config["lora_alpha"], + lora_dropout=peft_config["lora_dropout"], + init_lora_weights=peft_config["init_lora_weights"], + ) + child.train(False) + if peft_config["lora_dropout"] > 0: + logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout") + for p in child.parameters(): + p.requires_grad = False + + if peft_key.endswith(".lora_A.weight"): + child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key] + is_lora_a_loaded = True + elif peft_key.endswith(".lora_A.bias"): + raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") + elif peft_key.endswith(".lora_B.weight"): + child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key] + is_lora_b_loaded = True + elif peft_key.endswith(".lora_B.bias"): + raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") + + if is_lora_a_loaded and is_lora_b_loaded: + logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully") diff --git a/tests/test_full_model.py b/tests/test_full_model.py index f2679f2..acd5e6a 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -1,3 +1,4 @@ +import peft import pytest import torch import transformers @@ -12,11 +13,16 @@ logger = get_logger(__name__) @pytest.mark.forked +@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,)) @pytest.mark.parametrize("pass_empty_tensors", (True, False)) -def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3): +def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3): tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME) model = DistributedBloomForCausalLM.from_pretrained( - MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32 + MODEL_NAME, + initial_peers=INITIAL_PEERS, + low_cpu_mem_usage=True, + torch_dtype=torch.float32, + active_adapter=ADAPTER_NAME if use_peft else None, ) config = model.config assert isinstance(model, DistributedBloomForCausalLM) @@ -54,6 +60,9 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato ref_model = transformers.BloomForCausalLM.from_pretrained( REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 ) + if use_peft: + ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME) + ref_model.train(False) if config.vocab_size < ref_model.config.vocab_size: ref_model.resize_token_embeddings(config.vocab_size) logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}") diff --git a/tests/test_peft.py b/tests/test_peft.py new file mode 100644 index 0000000..7ac4f80 --- /dev/null +++ b/tests/test_peft.py @@ -0,0 +1,66 @@ +import os +import shutil + +import pytest +from huggingface_hub import snapshot_download + +from petals.utils.peft import check_peft_repository, load_peft + +UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft" +SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft" +TMP_CACHE_DIR = "tmp_cache/" + + +def clear_dir(path_to_dir): + shutil.rmtree(path_to_dir) + os.mkdir(path_to_dir) + + +def dir_empty(path_to_dir): + files = os.listdir(path_to_dir) + return len(files) == 0 + + +@pytest.mark.forked +def test_check_peft(): + assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load." + assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load." + + +@pytest.mark.forked +def test_load_noncached(tmpdir): + clear_dir(tmpdir) + with pytest.raises(Exception): + load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir) + + assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded" + + load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir) + + assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded" + + +@pytest.mark.forked +def test_load_cached(tmpdir): + clear_dir(tmpdir) + snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir) + + load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir) + + +@pytest.mark.forked +def test_load_layer_exists(tmpdir): + clear_dir(tmpdir) + + load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir) + + +@pytest.mark.forked +def test_load_layer_nonexists(tmpdir): + clear_dir(tmpdir) + + load_peft( + SAFE_PEFT_REPO, + block_idx=1337, + cache_dir=tmpdir, + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index ee440d6..e40d235 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,3 +11,5 @@ if not MODEL_NAME: raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested") REF_NAME = os.environ.get("REF_NAME") + +ADAPTER_NAME = os.environ.get("ADAPTER_NAME") From 13f4e3a88aa89b566aaec7f01341bed8e340a34d Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Wed, 12 Jul 2023 15:50:54 +0300 Subject: [PATCH 104/168] Fix convergence issues and switch to LLaMA in the SST-2 example (#343) * Fix convergence issues and switch to LLaMA in the SST-2 example --- .gitignore | 2 + examples/prompt-tuning-sst2.ipynb | 315 ++++++++++-------------------- 2 files changed, 105 insertions(+), 212 deletions(-) diff --git a/.gitignore b/.gitignore index 7114a35..d8c10af 100644 --- a/.gitignore +++ b/.gitignore @@ -126,3 +126,5 @@ dmypy.json # Pyre type checker .pyre/ + +.idea/ diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index c5dac6a..876db8f 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -3,17 +3,19 @@ { "cell_type": "markdown", "id": "a07e0f5e", - "metadata": {}, + "metadata": { + "id": "a07e0f5e" + }, "source": [ "
\n", " \n", "
\n", "\n", - "# Distributed Bloom for Text Classification using Prompt Tuning\n", + "# Distributed LLaMA for Text Classification using Prompt Tuning\n", "\n", - "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", + "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [LLaMA](https://github.com/facebookresearch/llama) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the LLaMA blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", "\n", - "We will adapt BLOOM for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", + "We will adapt LLaMA for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", "\n", "To use this notebook in Colab:\n", "\n", @@ -24,7 +26,9 @@ { "cell_type": "markdown", "id": "a3f8526f", - "metadata": {}, + "metadata": { + "id": "a3f8526f" + }, "source": [ "First, we have to prepare all dependencies." ] @@ -33,17 +37,22 @@ "cell_type": "code", "execution_count": null, "id": "73bbc648", - "metadata": {}, + "metadata": { + "id": "73bbc648" + }, "outputs": [], "source": [ - "%pip install -q petals datasets wandb scikit-learn" + "%pip install -q datasets wandb scikit-learn\n", + "%pip install -q git+https://github.com/bigscience-workshop/petals@main" ] }, { "cell_type": "code", "execution_count": null, "id": "b4ab6ca7", - "metadata": {}, + "metadata": { + "id": "b4ab6ca7" + }, "outputs": [], "source": [ "import os\n", @@ -57,15 +66,19 @@ "from tqdm import tqdm\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", - "from transformers import BloomTokenizerFast, get_scheduler\n", + "from transformers import LlamaTokenizer, get_scheduler, set_seed\n", "\n", - "from petals import DistributedBloomForSequenceClassification" + "from petals import DistributedLlamaForSequenceClassification\n", + "\n", + "set_seed(0)" ] }, { "cell_type": "markdown", "id": "1bf07b5d", - "metadata": {}, + "metadata": { + "id": "1bf07b5d" + }, "source": [ "Let's set some hyperparameters for training:" ] @@ -74,14 +87,15 @@ "cell_type": "code", "execution_count": null, "id": "f04ba4d2", - "metadata": {}, + "metadata": { + "id": "f04ba4d2" + }, "outputs": [], "source": [ "# Choose a model you'd like to prompt-tune. We recommend starting with\n", - "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n", - "# Once your code is ready, you can switch to full-scale\n", - "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n", - "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n", + "# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n", + "# The code below uses LLaMA-65B.\n", + "MODEL_NAME = \"enoch/llama-65b-hf\"\n", "\n", "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", "# The latter fine-tunes separate prefixes for each transformer block,\n", @@ -89,9 +103,9 @@ "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n", "TUNING_MODE = 'ptune'\n", "\n", - "NUM_PREFIX_TOKENS = 16\n", + "NUM_PREFIX_TOKENS = 8\n", "DEVICE = 'cuda'\n", - "BATCH_SIZE = 16\n", + "BATCH_SIZE = 32\n", "LR = 1e-2\n", "WEIGHT_DECAY = 0.0\n", "NUM_EPOCHS = 3\n", @@ -102,32 +116,40 @@ { "cell_type": "markdown", "id": "d38316bd", - "metadata": {}, + "metadata": { + "id": "d38316bd" + }, "source": [ - "Prepare tokenizer and distributed model, connect it to servers." + "Here, we prepare tokenizer and distributed model and connect it to the public swarm." ] }, { "cell_type": "code", "execution_count": null, "id": "03c6e53e", - "metadata": {}, + "metadata": { + "id": "03c6e53e" + }, "outputs": [], "source": [ - "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n", + "tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\n", "tokenizer.padding_side = 'right'\n", "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", - "model = DistributedBloomForSequenceClassification.from_pretrained(\n", + "tokenizer.pad_token = tokenizer.unk_token\n", + "model = DistributedLlamaForSequenceClassification.from_pretrained(\n", " MODEL_NAME,\n", " pre_seq_len=NUM_PREFIX_TOKENS,\n", " tuning_mode=TUNING_MODE\n", - ").to(DEVICE)" + ").float().to(DEVICE)\n", + "model.config.pad_token_id = tokenizer.pad_token_id" ] }, { "cell_type": "markdown", "id": "042e3786", - "metadata": {}, + "metadata": { + "id": "042e3786" + }, "source": [ "Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset." ] @@ -136,7 +158,9 @@ "cell_type": "code", "execution_count": null, "id": "9c44d516", - "metadata": {}, + "metadata": { + "id": "9c44d516" + }, "outputs": [], "source": [ "task = 'sst2'\n", @@ -144,7 +168,7 @@ "dataset = load_dataset(\"glue\", task)\n", "\n", "def preprocess_function(examples):\n", - " return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True)\n", + " return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True, return_token_type_ids=False)\n", "\n", "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n", "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n", @@ -161,16 +185,20 @@ { "cell_type": "markdown", "id": "2a3f3590", - "metadata": {}, + "metadata": { + "id": "2a3f3590" + }, "source": [ - "To check training, we need a metric function. For SST-2 task is accuracy. We will load it from the datasets library." + "To monitor training, we need the metric function. For SST-2, the target metric is accuracy. We will load it from the datasets library." ] }, { "cell_type": "code", "execution_count": null, "id": "1e1812be", - "metadata": {}, + "metadata": { + "id": "1e1812be" + }, "outputs": [], "source": [ "metric = load_metric('glue', task)\n", @@ -179,7 +207,7 @@ " model.eval()\n", " for batch in dataloader:\n", " batch = {k: v.to(device) for k, v in batch.items()}\n", - " \n", + "\n", " with torch.no_grad():\n", " outputs = model(**batch)\n", "\n", @@ -193,16 +221,20 @@ { "cell_type": "markdown", "id": "ef4323fd", - "metadata": {}, + "metadata": { + "id": "ef4323fd" + }, "source": [ - "Before setting up optimizers, check the model parameters that will be trained." + "Before setting up optimizers, let's check the model parameters that will be trained." ] }, { "cell_type": "code", "execution_count": null, "id": "9cc0ba34", - "metadata": {}, + "metadata": { + "id": "9cc0ba34" + }, "outputs": [], "source": [ "for n, p in model.named_parameters():\n", @@ -213,29 +245,35 @@ { "cell_type": "markdown", "id": "59cffce7", - "metadata": {}, + "metadata": { + "id": "59cffce7" + }, "source": [ - "The optimizer will only work on **prompts**, they are only trainable parameters. Let's initialize optimizer and learning rate scheduler." + "The optimizer will only work on **prompts and classifier head**: they are only trainable parameters. Let's initialize the optimizer and the learning rate scheduler." ] }, { "cell_type": "code", "execution_count": null, "id": "ef9bf344", - "metadata": {}, + "metadata": { + "id": "ef9bf344" + }, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "\n", "lr_scheduler = get_scheduler(\n", - " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", + " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\n", ")" ] }, { "cell_type": "markdown", "id": "423c56d5", - "metadata": {}, + "metadata": { + "id": "423c56d5" + }, "source": [ "Let's initialize wandb for logging and start the training loop!" ] @@ -244,7 +282,9 @@ "cell_type": "code", "execution_count": null, "id": "d9e46807", - "metadata": {}, + "metadata": { + "id": "d9e46807" + }, "outputs": [], "source": [ "wandb.init(\n", @@ -260,20 +300,24 @@ " }\n", ")\n", "\n", + "scaler = torch.cuda.amp.GradScaler()\n", + "\n", "for epoch in range(NUM_EPOCHS):\n", + " model.train()\n", " for batch in tqdm(train_dataloader):\n", " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", "\n", - " model.train()\n", - " outputs = model(**batch)\n", + " with torch.autocast(device_type=DEVICE, dtype=torch.float16):\n", + " outputs = model(**batch)\n", " loss = outputs.loss\n", - " loss.backward()\n", + " scaler.scale(loss).backward()\n", "\n", - " optimizer.step()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", - " wandb.log({\"Train Loss\": loss})\n", + " wandb.log({\"Train Loss\": loss.detach()})\n", "\n", " accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n", " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)" @@ -282,184 +326,26 @@ { "cell_type": "markdown", "id": "51770911", - "metadata": {}, - "source": [ - "Our model have been trained!" - ] - }, - { - "cell_type": "markdown", - "id": "1bbf014f", - "metadata": {}, - "source": [ - "## Beyond soft-prompt tuning\n", - "\n", - "Let's try to tune model using adapters in the middle of the model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3bea4391", - "metadata": {}, - "outputs": [], - "source": [ - "class BloomBasedClassifier(nn.Module):\n", - " def __init__(\n", - " self,\n", - " model,\n", - " intermediate_size: int = 32,\n", - " num_classes: int = 2,\n", - " adapter_layer_position: int = 6,\n", - " head_layer_position: int = 10\n", - " ):\n", - " super().__init__()\n", - " self.distributed_layers = model.transformer.h\n", - "\n", - " self.hidden_size = model.config.hidden_size\n", - " self.dtype = model.config.torch_dtype\n", - " self.intermediate_size = intermediate_size\n", - " self.num_classes = num_classes\n", - " self.adapter_layer_position = adapter_layer_position\n", - " self.head_layer_position = head_layer_position\n", - " \n", - " self.word_embeddings = model.transformer.word_embeddings\n", - " self.adapter = nn.Sequential(\n", - " nn.Linear(self.hidden_size, self.intermediate_size),\n", - " nn.Linear(self.intermediate_size, self.hidden_size),\n", - " ).to(self.dtype)\n", - " self.head = nn.Sequential(\n", - " nn.LayerNorm(self.hidden_size),\n", - " nn.Linear(self.hidden_size, self.num_classes),\n", - " ).to(self.dtype)\n", - " \n", - " def forward(self, embeddings):\n", - " before_layers = self.distributed_layers[0:self.adapter_layer_position]\n", - " after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n", - " \n", - " hidden_states = before_layers(embeddings)\n", - " hidden_states = self.adapter(hidden_states)\n", - " hidden_states = after_layers(hidden_states)\n", - " pooled_states = torch.mean(hidden_states, dim=1)\n", - " return self.head(pooled_states)" - ] - }, - { - "cell_type": "markdown", - "id": "15299620", - "metadata": {}, - "source": [ - "Clear model and device memory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa27b168", - "metadata": {}, - "outputs": [], - "source": [ - "del model, optimizer, lr_scheduler\n", - "torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "markdown", - "id": "5406390f", - "metadata": {}, - "source": [ - "Create new model with adapters." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a251db80", - "metadata": {}, - "outputs": [], - "source": [ - "INTERMEDIATE_SIZE = 32\n", - "ADAPTER_LAYER_POSITION = 6\n", - "HEAD_LAYER_POSITION = 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3578df3a", - "metadata": {}, - "outputs": [], - "source": [ - "cls_model = BloomBasedClassifier(\n", - " DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n", - " intermediate_size=INTERMEDIATE_SIZE,\n", - " adapter_layer_position=ADAPTER_LAYER_POSITION,\n", - " head_layer_position=HEAD_LAYER_POSITION,\n", - ").to(DEVICE)\n", - "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", - "cls_criterion = nn.CrossEntropyLoss()\n", - "\n", - "lr_scheduler = get_scheduler(\n", - " name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a40468b9", - "metadata": {}, + "metadata": { + "id": "51770911" + }, "source": [ - "And start training our new adapted model." + "Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](http://health.petals.ml/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!" ] }, { "cell_type": "code", "execution_count": null, - "id": "ed051a5d", - "metadata": {}, "outputs": [], - "source": [ - "wandb.init(\n", - " project=\"bloom_based_cls-sst-2\",\n", - " config={\n", - " \"num_epochs\": NUM_EPOCHS,\n", - " \"batch_size\": BATCH_SIZE,\n", - " \"learning_rate\": LR,\n", - " \"weight_decay\": WEIGHT_DECAY,\n", - " \"model_name\": MODEL_NAME,\n", - " \"seed\": SEED,\n", - " \"intermediate_size\": INTERMEDIATE_SIZE,\n", - " \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n", - " \"head_layer_position\": HEAD_LAYER_POSITION,\n", - " }\n", - ")\n", - "\n", - "for epoch in range(NUM_EPOCHS):\n", - " for batch in tqdm(train_dataloader):\n", - " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", - "\n", - " cls_model.train()\n", - " with torch.no_grad():\n", - " embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n", - " outputs = cls_model(embeddings_output)\n", - " loss = cls_criterion(outputs, batch[\"labels\"])\n", - " loss.backward()\n", - "\n", - " cls_optimizer.step()\n", - " lr_scheduler.step()\n", - " cls_optimizer.zero_grad()\n", - "\n", - " wandb.log({\"Train Loss\": loss})\n", - "\n", - " accuracy = eval_metrics(cls_model, valid_dataloader, device=DEVICE)\n", - " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)" - ] + "source": [], + "metadata": { + "collapsed": false + } } ], "metadata": { "kernelspec": { "display_name": "Python 3", - "language": "python", "name": "python3" }, "language_info": { @@ -478,7 +364,12 @@ "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } - } + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 5 From 515a5120cb5a69c1c5e38d7c9b52fcc6c9e27252 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 12 Jul 2023 16:58:58 +0400 Subject: [PATCH 105/168] Mention LLaMA in readme (#344) --- README.md | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index f157cc0..39e9c5f 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,12 @@

-Generate text using distributed 176B-parameter [BLOOM](https://huggingface.co/bigscience/bloom) or [BLOOMZ](https://huggingface.co/bigscience/bloomz) and fine-tune them for your own tasks: +Generate text using distributed [LLaMA-65B](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md), [BLOOM-176B](https://huggingface.co/bigscience/bloom) or [BLOOMZ-176B](https://huggingface.co/bigscience/bloomz) and fine-tune them for your own tasks: ```python -from petals import DistributedBloomForCausalLM +from petals import AutoDistributedModelForCausalLM -model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", tuning_mode="ptune", pre_seq_len=16) +model = AutoDistributedModelForCausalLM.from_pretrained("bigscience/bloom", tuning_mode="ptune", pre_seq_len=16) # Embeddings & prompts are on your device, BLOOM blocks are distributed across the Internet inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"] @@ -39,7 +39,7 @@ Run our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windo ```bash sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ - learningathome/petals:main python -m petals.cli.run_server bigscience/bloom-petals --port 31330 + learningathome/petals:main python -m petals.cli.run_server bigscience/bloom --port 31330 ``` Or run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+): @@ -47,13 +47,11 @@ Or run these commands in an [Anaconda](https://www.anaconda.com) env (requires L ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install -U petals -python -m petals.cli.run_server bigscience/bloom-petals +python -m petals.cli.run_server bigscience/bloom ``` 📚 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to configure the server to use multiple GPUs, address common issues, etc. -You can also host [BLOOMZ](https://huggingface.co/bigscience/bloomz), a version of BLOOM fine-tuned to follow human instructions in the zero-shot regime — just replace `bloom-petals` with `bloomz-petals`. - 🔒 Hosting a server does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). 💬 If you have any issues or feedback, let us know on [our Discord server](https://discord.gg/D9MwApKgWa)! From 294970fe183dc481a3683c6b3a2ea94754f7cb36 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 12 Jul 2023 17:00:15 +0400 Subject: [PATCH 106/168] Update Colab link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 39e9c5f..68564ab 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ for input_ids, labels in data_loader: ```

- 🚀  Try now in Colab + 🚀  Try now in Colab

🔏 Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. From 43acfe52a7bfd58e17df97053497a817b50c634a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 12 Jul 2023 23:15:16 +0400 Subject: [PATCH 107/168] Import petals.utils.peft only when needed to avoid unnecessary import of bitsandbytes (#345) The motivation is the same as in #180. --- src/petals/server/backend.py | 7 +++++-- src/petals/utils/convert_block.py | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 9e81170..f6b9691 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -4,7 +4,6 @@ from collections import Counter from itertools import chain from typing import Any, Dict, Optional, Sequence, Tuple, Union -import peft import torch from hivemind import BatchTensorDescriptor, TensorDescriptor from hivemind.moe.expert_uid import ExpertUID @@ -156,9 +155,13 @@ class TransformerBackend(ModuleBackend): def load_adapter_(self, active_adapter: Optional[str] = None) -> bool: """Activate a given adapter set if available. Return True if available (or no adapter), False if missing""" + + # Import petals.utils.peft only when necessary to avoid importing bitsandbytes + from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt + adapter_was_loaded = False for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter - if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)): + if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)): layer.active_adapter = active_adapter # empty string for no adapter if active_adapter in layer.lora_A.keys(): adapter_was_loaded = True diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index b1c412e..5c04092 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -13,7 +13,6 @@ from tensor_parallel.slicing_configs import get_bloom_config from transformers import PretrainedConfig from petals.utils.misc import QuantType -from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) @@ -56,6 +55,10 @@ def convert_block( shard.to(device) if adapters: + # Import petals.utils.peft only when necessary to avoid importing bitsandbytes + os.environ["BITSANDBYTES_NOWELCOME"] = "1" + from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft + create_lora_adapter(block, quant_type=quant_type) for adapter_name in adapters: adapter_config, adapter_state_dict = load_peft( From 90fbaab61e5c2e10b418160e9a36a840f2e7db85 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 13 Jul 2023 19:34:17 +0400 Subject: [PATCH 108/168] Fix Docker build by avoiding Python 3.11 (#348) We want to use `3.10.x` since `grpcio-tools` is not compatible with 3.11 yet. However, `python~=3.10` meant `python>=3.10, python<4.0`, so we ended up with a broken build due to python 3.11 installed. --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index fd03398..b1a2676 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,7 +17,7 @@ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh - bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh ENV PATH="/opt/conda/bin:${PATH}" -RUN conda install python~=3.10 pip && \ +RUN conda install python~=3.10.12 pip && \ pip install --no-cache-dir "torch>=1.12" && \ conda clean --all && rm -rf ~/.cache/pip From f605f093f73a203499cd183073fe73e8380da297 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 14 Jul 2023 00:43:28 +0400 Subject: [PATCH 109/168] Support LLaMA repos without "-hf" suffix (#349) --- src/petals/models/llama/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py index dd5f6b1..78443eb 100644 --- a/src/petals/models/llama/config.py +++ b/src/petals/models/llama/config.py @@ -33,5 +33,7 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM dht_prefix = str(model_name_or_path) if "/" in dht_prefix: # If present, strip repository name to merge blocks hosted by different accounts dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :] + if not dht_prefix.endswith("-hf"): + dht_prefix += "-hf" logger.info(f"Using DHT prefix: {dht_prefix}") return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) From 010857a834480dec999042498d505989f829c43e Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 14 Jul 2023 01:03:42 +0300 Subject: [PATCH 110/168] Estimate adapter memory overhead in choose_num_blocks() (#346) * estimate adapter memory overhead * reduce number of heads based on that --------- Co-authored-by: Alexander Borzunov --- src/petals/server/server.py | 17 +++++++++++---- src/petals/utils/convert_block.py | 2 -- src/petals/utils/peft.py | 35 ++++++++++++++++++++++++++++--- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 643bf1b..e576d00 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -30,6 +30,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR +from petals.utils.peft import estimate_adapter_memory_per_block from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) @@ -176,6 +177,8 @@ class Server: cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8 + self.cache_dir = cache_dir + self.adapters = adapters assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" if num_blocks is None and block_indices is None: @@ -197,7 +200,6 @@ class Server: self.alloc_timeout = alloc_timeout if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR - self.cache_dir = cache_dir self.max_disk_space = max_disk_space assert isinstance(throughput, float) or throughput in ["auto", "eval"] @@ -219,8 +221,6 @@ class Server: self.mean_balance_check_period = mean_balance_check_period self.mean_block_selection_delay = mean_block_selection_delay - self.adapters = adapters - self.stop = threading.Event() def _choose_num_blocks(self) -> int: @@ -250,7 +250,16 @@ class Server: # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models) autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size - num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block)) + if adapters: + # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes + from petals.utils.peft import estimate_adapter_memory_per_block + + adapter_memory_per_block = estimate_adapter_memory_per_block( + self.block_config, self.torch_dtype, self.adapters, self.cache_dir + ) + total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block + + num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" num_blocks = min(num_blocks, self.block_config.num_hidden_layers) diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 5c04092..b75709d 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -55,8 +55,6 @@ def convert_block( shard.to(device) if adapters: - # Import petals.utils.peft only when necessary to avoid importing bitsandbytes - os.environ["BITSANDBYTES_NOWELCOME"] = "1" from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft create_lora_adapter(block, quant_type=quant_type) diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index c551f97..7eb72ef 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -1,9 +1,16 @@ +import os import re import time -from typing import List, Optional +from typing import List, Optional, Sequence + +os.environ["BITSANDBYTES_NOWELCOME"] = "1" import bitsandbytes as bnb +import peft +import torch import torch.nn as nn +import transformers +from accelerate import init_empty_weights from hivemind.utils.logging import get_logger from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url from peft.tuners import lora @@ -12,6 +19,8 @@ from safetensors import safe_open from safetensors.torch import load_file from transformers.utils import get_file_from_repo +from petals.client.ptune import force_non_empty_weights +from petals.server.block_utils import resolve_block_dtype from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for from petals.utils.misc import QuantType @@ -194,15 +203,35 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta p.requires_grad = False if peft_key.endswith(".lora_A.weight"): - child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key] + child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key] is_lora_a_loaded = True elif peft_key.endswith(".lora_A.bias"): raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") elif peft_key.endswith(".lora_B.weight"): - child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key] + child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key] is_lora_b_loaded = True elif peft_key.endswith(".lora_B.bias"): raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") if is_lora_a_loaded and is_lora_b_loaded: logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully") + + +def estimate_adapter_memory_per_block( + block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs +) -> int: + """Get the number of extra bytes used to store a set of adapters per given block""" + with init_empty_weights(include_buffers=True): + block = block_config.block_class(block_config) + base_block_parameters = sum(p.numel() for p in block.parameters()) + create_lora_adapter(block, quant_type=QuantType.NONE) + + for adapter in adapters: + peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs) + assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now" + add_adapter_to_block( + block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict + ) + adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters + bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8 + return adapter_parameters * bytes_per_parameter From e12d4c666bc34c45e2d0a58e2f03a01fe13a0d97 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 14 Jul 2023 02:52:52 +0400 Subject: [PATCH 111/168] Spam less in server logs (#350) --- src/petals/__init__.py | 2 ++ src/petals/server/backend.py | 17 ++++++++--------- src/petals/server/handler.py | 4 ++-- src/petals/utils/convert_block.py | 1 - src/petals/utils/peft.py | 23 +++++++++++------------ 5 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index f007d11..5658167 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1,5 +1,7 @@ import os +os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1") + import hivemind import transformers from packaging import version diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index f6b9691..51c6ee0 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -82,14 +82,12 @@ class TransformerBackend(ModuleBackend): def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: *inputs, active_adapter = inputs - if not self.load_adapter_(active_adapter): - raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded") + self.load_adapter_(active_adapter) return super().forward(*inputs) def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: *inputs, active_adapter = inputs - if not self.load_adapter_(active_adapter): - raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded") + self.load_adapter_(active_adapter) return super().backward(*inputs) @torch.inference_mode() @@ -100,8 +98,7 @@ class TransformerBackend(ModuleBackend): inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" - if not self.load_adapter_(inference_info.active_adapter): - raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded") + self.load_adapter_(inference_info.active_adapter) with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors: self._reorder_cache_inplace(cache_tensors, hypo_ids) layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) @@ -159,13 +156,15 @@ class TransformerBackend(ModuleBackend): # Import petals.utils.peft only when necessary to avoid importing bitsandbytes from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt - adapter_was_loaded = False + loaded = False for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)): layer.active_adapter = active_adapter # empty string for no adapter if active_adapter in layer.lora_A.keys(): - adapter_was_loaded = True - return adapter_was_loaded or not active_adapter + loaded = True + + if active_adapter and not loaded: + raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded") def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]): diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index d7295ca..d9a5025 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -307,10 +307,10 @@ class TransformerConnectionHandler(ConnectionHandler): """Directly push activation tensors from one server to another""" requested_uids = self._check_uids(request.uid) - self._log_request("rpc_push", requested_uids, context) - metadata = MSGPackSerializer.loads(request.metadata) session_id = metadata["session_id"] + self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}") + self._session_queues[session_id].put(request) return runtime_pb2.ExpertResponse() diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index b75709d..dfb5a24 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -71,7 +71,6 @@ def convert_block( def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module: # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes - os.environ["BITSANDBYTES_NOWELCOME"] = "1" import bitsandbytes as bnb for n, module in model.named_children(): diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index 7eb72ef..4f51643 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -1,12 +1,8 @@ -import os import re import time -from typing import List, Optional, Sequence - -os.environ["BITSANDBYTES_NOWELCOME"] = "1" +from typing import Optional, Sequence import bitsandbytes as bnb -import peft import torch import torch.nn as nn import transformers @@ -19,7 +15,6 @@ from safetensors import safe_open from safetensors.torch import load_file from transformers.utils import get_file_from_repo -from petals.client.ptune import force_non_empty_weights from petals.server.block_utils import resolve_block_dtype from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for from petals.utils.misc import QuantType @@ -124,7 +119,7 @@ def load_peft( def create_lora_adapter(block, quant_type: QuantType): - for name, module in block.named_modules(): + for _, module in block.named_modules(): for child_name, child in module.named_children(): lora_wrapped_child = None if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)): @@ -173,7 +168,10 @@ def create_lora_adapter(block, quant_type: QuantType): def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict): assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters" - for name, module in block.named_modules(): + if peft_config["lora_dropout"] > 0: + logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout") + + for _, module in block.named_modules(): for child_name, child in module.named_children(): if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)): continue @@ -185,7 +183,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta is_lora_a_loaded = False is_lora_b_loaded = False for peft_key in peft_state_dict: - if peft_key.find(child_name) == -1: + if child_name not in peft_key: continue if adapter_name not in child.lora_A: @@ -197,8 +195,6 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta init_lora_weights=peft_config["init_lora_weights"], ) child.train(False) - if peft_config["lora_dropout"] > 0: - logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout") for p in child.parameters(): p.requires_grad = False @@ -214,7 +210,10 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") if is_lora_a_loaded and is_lora_b_loaded: - logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully") + logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}") + elif is_lora_a_loaded or is_lora_b_loaded: + raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}") + logger.info(f"Loaded adapter {adapter_name} for block {block_index}") def estimate_adapter_memory_per_block( From c511990236173afe33f2b941e59a037a0a62aa5a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 14 Jul 2023 17:05:21 +0300 Subject: [PATCH 112/168] Remove unused import os (#352) --- src/petals/utils/convert_block.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index dfb5a24..299e979 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -1,7 +1,6 @@ """ Tools for converting transformer blocks, applying quantization and/or tensor parallelism """ -import os import re from typing import List, Optional, Sequence From 1a78638c02814e41edbfb18dbbe8a7b3e092ee6d Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 14 Jul 2023 18:40:47 +0400 Subject: [PATCH 113/168] Test that bitsandbytes is not imported when it's not used (#351) We avoid importing bitsandbytes when it's not used, since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that. --- setup.cfg | 2 +- src/petals/server/server.py | 1 - tests/test_aux_functions.py | 16 ++++++++++++++++ tests/test_sequence_manager.py | 4 ++-- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index f56a7cc..c6ae594 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ python_requires = >=3.7 install_requires = torch>=1.12 bitsandbytes==0.40.0.post4 - accelerate>=0.16.0,<1.0.0 + accelerate>=0.16.0,<0.21.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 transformers>=4.30.1,<5.0.0 diff --git a/src/petals/server/server.py b/src/petals/server/server.py index e576d00..be580d7 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -30,7 +30,6 @@ from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, check_device_balance, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR -from petals.utils.peft import estimate_adapter_memory_per_block from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index 5fa14db..e5b450c 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -1,3 +1,6 @@ +import subprocess +import sys + import pytest import torch @@ -7,6 +10,19 @@ from petals.utils.convert_block import QuantType from test_utils import MODEL_NAME +def test_bnb_not_imported_when_unnecessary(): + """ + We avoid importing bitsandbytes when it's not used, + since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that. + + If this test fails, please change your code to import bitsandbytes and/or petals.utils.peft + in the function's/method's code when it's actually needed instead of importing them in the beginning of the file. + This won't slow down the code - importing a module for the 2nd time doesn't rerun module code. + """ + + subprocess.check_call([sys.executable, "-c", "import petals, sys; assert 'bitsandbytes' not in sys.modules"]) + + @pytest.mark.forked @pytest.mark.parametrize("tensor_parallel", [False, True]) def test_compute_throughput(tensor_parallel: bool): diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 86d04ca..03e17e3 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -25,7 +25,7 @@ def test_sequence_manager_basics(mode: str): block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)] sequential = RemoteSequential( config, - sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt), + sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt), ) sequence = sequential.sequence_manager.make_sequence(mode=mode) @@ -43,7 +43,7 @@ def test_sequence_manager_basics(mode: str): assert shutdown_evt.is_set() -class TestSequenceManager(RemoteSequenceManager): +class RemoteSequenceManagerWithChecks(RemoteSequenceManager): """A sequence manager that signals if it was shut down""" def __init__(self, *args, _was_shut_down: threading.Event, **kwargs): From 9703358df08a6e58ce1f512bb795b482f7181566 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 14 Jul 2023 22:33:48 +0400 Subject: [PATCH 114/168] Fix bugs in _choose_num_blocks() added in #346 (#354) --- src/petals/server/server.py | 23 ++++++++++++++--------- src/petals/utils/peft.py | 7 +++++-- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index be580d7..c90ae44 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -174,9 +174,13 @@ class Server: self.quant_type = quant_type logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format") + # For attention cache in GPU or RAM cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8 + + # For disk cache self.cache_dir = cache_dir + self.max_disk_space = max_disk_space self.adapters = adapters assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" @@ -197,9 +201,6 @@ class Server: logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") self.alloc_timeout = alloc_timeout - if cache_dir is None: - cache_dir = DEFAULT_CACHE_DIR - self.max_disk_space = max_disk_space assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: @@ -243,20 +244,24 @@ class Server: else: total_memory = torch.cuda.get_device_properties(self.device).total_memory - block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type) - gib = 1024**3 # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models) autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size - if adapters: + block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type) + total_memory_per_block = block_size + self._cache_bytes_per_block + if self.adapters: # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes from petals.utils.peft import estimate_adapter_memory_per_block - adapter_memory_per_block = estimate_adapter_memory_per_block( - self.block_config, self.torch_dtype, self.adapters, self.cache_dir + total_memory_per_block += estimate_adapter_memory_per_block( + self.block_config, + self.torch_dtype, + self.adapters, + use_auth_token=self.use_auth_token, + cache_dir=self.cache_dir, + max_disk_space=self.max_disk_space, ) - total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index 4f51643..c537a32 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -217,7 +217,10 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta def estimate_adapter_memory_per_block( - block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs + block_config: transformers.PretrainedConfig, + torch_dtype: Optional[torch.dtype], + adapters: Sequence[str], + **load_peft_kwargs, ) -> int: """Get the number of extra bytes used to store a set of adapters per given block""" with init_empty_weights(include_buffers=True): @@ -226,7 +229,7 @@ def estimate_adapter_memory_per_block( create_lora_adapter(block, quant_type=QuantType.NONE) for adapter in adapters: - peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs) + peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs) assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now" add_adapter_to_block( block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict From 37fdcb3fe066a45ae80c3419cc60c658cbcbb594 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 14 Jul 2023 22:04:55 +0300 Subject: [PATCH 115/168] Switch adapters slightly faster (#353) Currently, each `TransformerBackend.inference_step` looks for adapters and sets the correct adapter type for each block. This is not very expensive, but it can measurably affect inference time. This pull request uses faster adapter switching with just one variable assignment, without iterating over block.modules(). --- src/petals/server/backend.py | 35 +++++++++-------------- src/petals/server/handler.py | 18 ++++++++---- src/petals/server/server.py | 1 + src/petals/utils/peft.py | 55 +++++++++++++++++++++++++++++++----- 4 files changed, 75 insertions(+), 34 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 51c6ee0..4220546 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -24,9 +24,15 @@ logger = get_logger(__name__) class TransformerBackend(ModuleBackend): """A wrapper for a transformer block that can process requests for forward, backward and inference""" + _peft_module = None + def __init__( self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs ): + import petals.utils.peft as _peft_module + + self._peft_module = _peft_module + super().__init__(*args, **kwargs) assert isinstance(self.module, TensorParallel) self.config = config @@ -82,13 +88,13 @@ class TransformerBackend(ModuleBackend): def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: *inputs, active_adapter = inputs - self.load_adapter_(active_adapter) - return super().forward(*inputs) + with self._peft_module.using_adapter(active_adapter): + return super().forward(*inputs) def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: *inputs, active_adapter = inputs - self.load_adapter_(active_adapter) - return super().backward(*inputs) + with self._peft_module.using_adapter(active_adapter): + return super().backward(*inputs) @torch.inference_mode() def inference_step( @@ -98,8 +104,9 @@ class TransformerBackend(ModuleBackend): inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" - self.load_adapter_(inference_info.active_adapter) - with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors: + with self.memory_cache.use_cache( + *inference_info.cache_handles + ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter): self._reorder_cache_inplace(cache_tensors, hypo_ids) layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True) @@ -150,22 +157,6 @@ class TransformerBackend(ModuleBackend): for p in self.module.parameters(): p.data = dummy - def load_adapter_(self, active_adapter: Optional[str] = None) -> bool: - """Activate a given adapter set if available. Return True if available (or no adapter), False if missing""" - - # Import petals.utils.peft only when necessary to avoid importing bitsandbytes - from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt - - loaded = False - for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter - if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)): - layer.active_adapter = active_adapter # empty string for no adapter - if active_adapter in layer.lora_A.keys(): - loaded = True - - if active_adapter and not loaded: - raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded") - def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]): """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call""" diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index d9a5025..12fd6eb 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -68,6 +68,7 @@ class TransformerConnectionHandler(ConnectionHandler): dht: DHT, module_backends: Dict[str, TransformerBackend], *, + adapters: Optional[Sequence[str]], dht_prefix: str, push_manager: multiprocessing.managers.SyncManager, session_queues: Dict[str, multiprocessing.managers.BaseProxy], # BaseProxy for queue.Queue @@ -81,6 +82,7 @@ class TransformerConnectionHandler(ConnectionHandler): for module_backend in self.module_backends.values(): assert isinstance(module_backend, TransformerBackend) self.dht_prefix = dht_prefix + self.adapters = adapters self._push_manager = push_manager self._session_queues = session_queues self._executor = ThreadPoolExecutor(max_workers=float("inf")) # For waiting on self.session_queues @@ -141,7 +143,7 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") @@ -355,7 +357,7 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) @@ -382,7 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler): self._log_request("rpc_forward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) @@ -433,7 +435,7 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) @@ -458,7 +460,7 @@ class TransformerConnectionHandler(ConnectionHandler): self._log_request("rpc_backward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) @@ -476,6 +478,12 @@ class TransformerConnectionHandler(ConnectionHandler): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): yield runtime_pb2.ExpertResponse(tensors=[part]) + def _get_active_adapter(self, metadata: dict) -> str: + active_adapter = metadata.get("active_adapter", "") + if active_adapter and (active_adapter not in self.adapters): + raise KeyError(f"adapter {active_adapter} not found") + return active_adapter + def _serialize_grads( self, grads: Sequence[torch.Tensor], diff --git a/src/petals/server/server.py b/src/petals/server/server.py index c90ae44..83a94e3 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -534,6 +534,7 @@ class ModuleContainer(threading.Thread): TransformerConnectionHandler( dht, self.module_backends, + adapters=adapters, dht_prefix=dht_prefix, push_manager=self.push_manager, session_queues=session_queues, diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index c537a32..b182181 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -1,3 +1,4 @@ +import contextlib import re import time from typing import Optional, Sequence @@ -118,6 +119,47 @@ def load_peft( time.sleep(delay) +class AdapterContextMixin: + """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context""" + + ADAPTER_NOT_SET = "__ADAPTER_NOT_SET" + _context_active_adapter = ADAPTER_NOT_SET + + @staticmethod + @contextlib.contextmanager + def using_adapter(active_adapter: Optional[str]): + prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter + try: + yield + finally: + AdapterContextMixin._context_active_adapter = prev + + @property + def active_adapter(self): + if self._context_active_adapter == self.ADAPTER_NOT_SET: + logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug") + return self._context_active_adapter + + @active_adapter.setter + def active_adapter(self, value: Optional[str]): + assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" "" + + +using_adapter = AdapterContextMixin.using_adapter + + +class LoraLinear(lora.Linear, AdapterContextMixin): + """LoRA linear layer that uses adapter selected via using_adapter""" + + +class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin): + """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter""" + + +class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin): + """LoRA linear 4-bit that uses adapter selected via using_adapter""" + + def create_lora_adapter(block, quant_type: QuantType): for _, module in block.named_modules(): for child_name, child in module.named_children(): @@ -130,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType): "threshold": 6.0, "bias": hasattr(child, "bias") and child.bias is not None, } - lora_wrapped_child = lora.Linear8bitLt( - child_name, + lora_wrapped_child = LoraLinear8bitLt( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, **kwargs, @@ -143,22 +185,21 @@ def create_lora_adapter(block, quant_type: QuantType): "blocksize": 64, "bias": hasattr(child, "bias") and child.bias is not None, } - lora_wrapped_child = lora.Linear4bit( - child_name, + lora_wrapped_child = LoraLinear4bit( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, **kwargs, ) else: bias = hasattr(child, "bias") and child.bias is not None - lora_wrapped_child = lora.Linear( - child_name, + lora_wrapped_child = LoraLinear( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, bias=bias, ) if lora_wrapped_child: - lora_wrapped_child.active_adapter = None lora_wrapped_child.weight = child.weight lora_wrapped_child.bias = child.bias for p in lora_wrapped_child.parameters(): From 2c8959e713754761fd8593d90acf913dcb9b2914 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 15 Jul 2023 03:36:31 +0400 Subject: [PATCH 116/168] Share more info about a server in DHT (#355) --- setup.cfg | 2 +- src/petals/__init__.py | 2 +- src/petals/cli/run_server.py | 5 +- src/petals/data_structures.py | 33 ++++++++++--- src/petals/dht_utils.py | 41 +++++----------- src/petals/models/bloom/config.py | 2 - src/petals/models/llama/config.py | 4 +- src/petals/server/handler.py | 3 +- src/petals/server/memory_cache.py | 4 ++ src/petals/server/server.py | 82 +++++++++++++++++-------------- src/petals/utils/convert_block.py | 4 +- 11 files changed, 95 insertions(+), 87 deletions(-) diff --git a/setup.cfg b/setup.cfg index c6ae594..0053628 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,7 +38,7 @@ install_requires = tokenizers>=0.13.3 transformers>=4.30.1,<5.0.0 speedtest-cli==2.1.3 - pydantic>=1.8.1,<2.0 # 2.0 is incompatible with hivemind==1.1.8 + pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind==1.1.8 hivemind==1.1.8 tensor_parallel==1.0.23 humanfriendly diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 5658167..3e67633 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.2.0.dev1" +__version__ = "1.2.0.dev2" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 6b3fde8..b2480f5 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -146,8 +146,9 @@ def main(): help="Skip checking this server's reachability via health.petals.ml " "when connecting to the public swarm. If you connect to a private swarm, " "the check is skipped by default. Use this option only if you know what you are doing") - - parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.") + + parser.add_argument("--adapters", nargs='+', default=(), + help="List of pre-loaded LoRA adapters that can be used for inference or training") # fmt:on args = vars(parser.parse_args()) diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 254faae..9e13ebe 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -1,10 +1,8 @@ -from __future__ import annotations - import dataclasses -from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple +import pydantic from hivemind import PeerID from hivemind.moe.expert_uid import ExpertUID @@ -21,13 +19,32 @@ class ServerState(Enum): ONLINE = 2 -@dataclass +@pydantic.dataclasses.dataclass class ServerInfo: state: ServerState - throughput: float + throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) + + adapters: Sequence[str] = () + version: Optional[str] = None + torch_dtype: Optional[str] = None + quant_type: Optional[str] = None + using_relay: Optional[bool] = None + cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None + + def to_tuple(self) -> Tuple[int, float, dict]: + extra_info = dataclasses.asdict(self) + del extra_info["state"], extra_info["throughput"] + return (self.state.value, self.throughput, extra_info) + + @classmethod + def from_tuple(cls, source: tuple): + state, throughput = source[:2] + extra_info = source[2] if len(source) > 2 else {} + # pydantic will validate existing fields and ignore extra ones + return cls(state=ServerState(state), throughput=throughput, **extra_info) -@dataclass +@dataclasses.dataclass class RemoteModuleInfo: """A remote module that is served by one or more servers""" @@ -35,7 +52,7 @@ class RemoteModuleInfo: servers: Dict[PeerID, ServerInfo] -@dataclass +@dataclasses.dataclass class RemoteSpanInfo: """A chain of remote blocks served by one specific remote peer""" diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index 99316f2..0710f60 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -11,7 +11,7 @@ from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.p2p import PeerID from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo logger = get_logger(__name__) @@ -19,10 +19,8 @@ logger = get_logger(__name__) def declare_active_modules( dht: DHT, uids: Sequence[ModuleUID], + server_info: ServerInfo, expiration_time: DHTExpiration, - state: ServerState, - throughput: float, - adapters: Optional[Sequence[str]] = None, wait: bool = True, ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]: """ @@ -42,14 +40,7 @@ def declare_active_modules( assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid return dht.run_coroutine( - partial( - _declare_active_modules, - uids=uids, - expiration_time=expiration_time, - state=state, - throughput=throughput, - adapters=list(adapters or []), - ), + partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time), return_future=not wait, ) @@ -58,16 +49,14 @@ async def _declare_active_modules( dht: DHT, node: DHTNode, uids: List[ModuleUID], + server_info: ServerInfo, expiration_time: DHTExpiration, - state: ServerState, - throughput: float, - adapters: List[str], ) -> Dict[ModuleUID, bool]: num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) return await node.store_many( keys=uids, subkeys=[dht.peer_id.to_base58()] * len(uids), - values=[(state.value, throughput, dict(adapters=adapters))] * len(uids), + values=[server_info.to_tuple()] * len(uids), expiration_time=expiration_time, num_workers=num_workers, ) @@ -115,29 +104,21 @@ async def _get_remote_module_infos( metadata = found[uid] if metadata is None or not isinstance(metadata.value, dict): if metadata is not None: - logger.error(f"Incorrect metadata for {uid}: {metadata}") + logger.warning(f"Incorrect metadata for {uid}: {metadata}") continue servers = {} for peer_id, server_info in metadata.value.items(): try: peer_id = PeerID.from_base58(peer_id) - state, throughput = server_info.value[:2] - extra_info = server_info.value[2] if len(server_info.value) > 2 else {} - adapters = extra_info.get("adapters", []) - if bool(active_adapter) and active_adapter not in adapters: + server_info = ServerInfo.from_tuple(server_info.value) + + if active_adapter and active_adapter not in server_info.adapters: logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}") continue - if not ( - isinstance(state, int) - and isinstance(throughput, float) - and math.isfinite(throughput) - and throughput >= 0.0 - ): - raise ValueError(f"Invalid server info: {server_info}") - servers[peer_id] = ServerInfo(ServerState(state), throughput) + servers[peer_id] = server_info except (TypeError, ValueError) as e: - logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}") + logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}") if servers: modules[i] = RemoteModuleInfo(uid, servers) return modules diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py index d6a8146..23521fc 100644 --- a/src/petals/models/bloom/config.py +++ b/src/petals/models/bloom/config.py @@ -9,8 +9,6 @@ from petals.client.lm_head import LMHeadConfig from petals.client.ptune import PTuneConfig from petals.client.routing.sequence_manager import SequenceManagerConfig from petals.models.bloom.block import WrappedBloomBlock -from petals.utils.auto_config import AutoDistributedConfig -from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py index 78443eb..b21fa9a 100644 --- a/src/petals/models/llama/config.py +++ b/src/petals/models/llama/config.py @@ -9,7 +9,6 @@ from petals.client.lm_head import LMHeadConfig from petals.client.ptune import PTuneConfig from petals.client.routing.sequence_manager import SequenceManagerConfig from petals.models.llama.block import WrappedLlamaBlock -from petals.utils.auto_config import AutoDistributedConfig logger = get_logger(__name__) @@ -31,8 +30,7 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: dht_prefix = str(model_name_or_path) - if "/" in dht_prefix: # If present, strip repository name to merge blocks hosted by different accounts - dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :] + dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts if not dht_prefix.endswith("-hf"): dht_prefix += "-hf" logger.info(f"Using DHT prefix: {dht_prefix}") diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 12fd6eb..d0531de 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -562,11 +562,10 @@ class TransformerConnectionHandler(ConnectionHandler): """Return metadata about stored block uids and current load""" backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values())) - cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes) result = { "version": petals.__version__, "dht_client_mode": self.dht.client_mode, - CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()), + CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()), } if request.uid: diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py index 7f00bae..a1e2f26 100644 --- a/src/petals/server/memory_cache.py +++ b/src/petals/server/memory_cache.py @@ -47,6 +47,10 @@ class MemoryCache: def current_size_bytes(self, value: int): self._current_size.value = value + @property + def bytes_left(self) -> int: + return self.max_size_bytes - self.current_size_bytes + @property def handle_counter(self) -> int: return self._handle_counter.value diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 83a94e3..bac93c5 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -16,8 +16,9 @@ from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger from transformers import PretrainedConfig +import petals from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace @@ -29,7 +30,6 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, check_device_balance, convert_block -from petals.utils.disk_cache import DEFAULT_CACHE_DIR from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) @@ -81,7 +81,7 @@ class Server: dht_client_mode: Optional[bool] = None, use_relay: bool = True, use_auto_relay: bool = True, - adapters: Optional[List[str]] = None, + adapters: Sequence[str] = (), **kwargs, ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" @@ -215,7 +215,15 @@ class Server: force_eval=(throughput == "eval"), cache_dir=cache_dir, ) - self.throughput = throughput + self.server_info = ServerInfo( + state=ServerState.JOINING, + throughput=throughput, + adapters=tuple(adapters), + version=petals.__version__, + torch_dtype=str(torch_dtype).lstrip("torch."), + quant_type=quant_type.name.lower(), + using_relay=self.dht.client_mode, + ) self.balance_quality = balance_quality self.mean_balance_check_period = mean_balance_check_period @@ -283,7 +291,7 @@ class Server: block_config=self.block_config, attn_cache_bytes=self.attn_cache_bytes, alloc_timeout=self.alloc_timeout, - throughput=self.throughput, + server_info=self.server_info, block_indices=block_indices, num_handlers=self.num_handlers, min_batch_size=self.min_batch_size, @@ -307,7 +315,6 @@ class Server: quant_type=self.quant_type, tensor_parallel_devices=self.tensor_parallel_devices, should_validate_reachability=self.should_validate_reachability, - adapters=self.adapters, start=True, ) try: @@ -385,7 +392,7 @@ class ModuleContainer(threading.Thread): block_config: PretrainedConfig, attn_cache_bytes: int, alloc_timeout: float, - throughput: float, + server_info: ServerInfo, block_indices: List[int], min_batch_size: int, max_batch_size: int, @@ -401,16 +408,18 @@ class ModuleContainer(threading.Thread): quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], should_validate_reachability: bool, - adapters: Optional[List[str]] = None, **kwargs, ) -> ModuleContainer: module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] + memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout) + + server_info.state = ServerState.JOINING joining_announcer = ModuleAnnouncerThread( module_uids, dht, - ServerState.JOINING, - adapters=adapters, - throughput=throughput, + server_info, + block_config=block_config, + memory_cache=memory_cache, update_period=update_period, expiration=expiration, daemon=True, @@ -420,7 +429,6 @@ class ModuleContainer(threading.Thread): assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices) - memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout) blocks = {} try: for module_uid, block_index in zip(module_uids, block_indices): @@ -441,7 +449,7 @@ class ModuleContainer(threading.Thread): tensor_parallel_devices, device, quant_type, - adapters=adapters, + adapters=server_info.adapters, freeze=True, use_auth_token=use_auth_token, cache_dir=cache_dir, @@ -477,13 +485,12 @@ class ModuleContainer(threading.Thread): joining_announcer.stop.set() joining_announcer.join() + server_info.state = ServerState.OFFLINE declare_active_modules( dht, module_uids, + server_info, expiration_time=get_dht_time() + expiration, - state=ServerState.OFFLINE, - throughput=throughput, - adapters=adapters, ) logger.info(f"Announced that blocks {module_uids} are offline") raise @@ -497,8 +504,9 @@ class ModuleContainer(threading.Thread): dht, dht_prefix, blocks, - adapters=adapters, - throughput=throughput, + block_config=block_config, + memory_cache=memory_cache, + server_info=server_info, update_period=update_period, expiration=expiration, **kwargs, @@ -510,10 +518,11 @@ class ModuleContainer(threading.Thread): dht_prefix: str, module_backends: Dict[str, TransformerBackend], *, + block_config: PretrainedConfig, + memory_cache: MemoryCache, inference_max_length: int, num_handlers: int, - throughput: float, - adapters: Optional[Sequence[str]], + server_info: ServerInfo, update_period: float, expiration: Optional[float] = None, request_timeout: float, @@ -525,7 +534,7 @@ class ModuleContainer(threading.Thread): super().__init__() self.dht, self.module_backends = dht, module_backends - self.throughput, self.update_period, self.expiration = throughput, update_period, expiration + self.server_info, self.update_period, self.expiration = server_info, update_period, expiration self.push_manager = mp.Manager() self.push_manager.__enter__() @@ -534,7 +543,7 @@ class ModuleContainer(threading.Thread): TransformerConnectionHandler( dht, self.module_backends, - adapters=adapters, + adapters=server_info.adapters, dht_prefix=dht_prefix, push_manager=self.push_manager, session_queues=session_queues, @@ -548,12 +557,14 @@ class ModuleContainer(threading.Thread): self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs) # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed. + + self.server_info.state = ServerState.ONLINE self.online_announcer = ModuleAnnouncerThread( list(self.module_backends.keys()), dht, - ServerState.ONLINE, - adapters=adapters, - throughput=throughput, + self.server_info, + block_config=block_config, + memory_cache=memory_cache, update_period=update_period, expiration=expiration, daemon=True, @@ -613,12 +624,12 @@ class ModuleContainer(threading.Thread): self.online_announcer.stop.set() self.online_announcer.join() + self.server_info.state = ServerState.OFFLINE declare_active_modules( self.dht, self.module_backends.keys(), + self.server_info, expiration_time=get_dht_time() + self.expiration, - state=ServerState.OFFLINE, - throughput=self.throughput, ) logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline") @@ -651,10 +662,10 @@ class ModuleAnnouncerThread(threading.Thread): self, module_uids: List[str], dht: DHT, - state: ServerState, - adapters: Optional[Sequence[str]], + server_info: ServerInfo, *, - throughput: float, + block_config: PretrainedConfig, + memory_cache: MemoryCache, update_period: float = 30, expiration: float, **kwargs, @@ -662,22 +673,21 @@ class ModuleAnnouncerThread(threading.Thread): super().__init__(**kwargs) self.module_uids = module_uids self.dht = dht - self.state = state - self.adapters = adapters - self.throughput = throughput + self.server_info = server_info + self.memory_cache = memory_cache + self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8 self.update_period = update_period self.expiration = expiration self.stop = threading.Event() def run(self) -> None: while True: + self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token declare_active_modules( self.dht, self.module_uids, + self.server_info, expiration_time=get_dht_time() + self.expiration, - state=self.state, - throughput=self.throughput, - adapters=self.adapters, ) if self.stop.wait(self.update_period): break diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 299e979..f8a4637 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -2,7 +2,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor parallelism """ import re -from typing import List, Optional, Sequence +from typing import Optional, Sequence import tensor_parallel as tp import torch @@ -25,7 +25,7 @@ def convert_block( output_device: torch.device, quant_type: QuantType, freeze: bool = True, - adapters: Optional[List[str]] = None, + adapters: Optional[Sequence[str]] = None, **kwargs, ) -> tp.TensorParallel: """ From 81c4a45ca2e50d77eaaff3d8fd66856f14d081c2 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 15 Jul 2023 20:16:21 +0400 Subject: [PATCH 117/168] Make a server ping next servers (#356) This PR makes a server ping potential next servers in a chain and report the RTTs to DHT. This will be used for shortest-path routing. --- src/petals/cli/run_server.py | 2 +- src/petals/data_structures.py | 1 + src/petals/server/server.py | 110 ++++++++++++++++++---------------- src/petals/utils/ping.py | 60 +++++++++++++++++++ 4 files changed, 120 insertions(+), 53 deletions(-) create mode 100644 src/petals/utils/ping.py diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index b2480f5..f2a0168 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -98,7 +98,7 @@ def main(): 'If set to "auto" (default), the script evaluates network and compute throughput ' 'on the first run and uses these estimates for future runs. ' 'If set to "eval", the script re-evaluates the throughput and overrides the cache.') - parser.add_argument('--update_period', type=float, required=False, default=150, + parser.add_argument('--update_period', type=float, required=False, default=60, help='Server will report blocks to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, help='DHT entries will expire after this many seconds') diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 9e13ebe..8d7d50b 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -30,6 +30,7 @@ class ServerInfo: quant_type: Optional[str] = None using_relay: Optional[bool] = None cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None + next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None def to_tuple(self) -> Tuple[int, float, dict]: extra_info = dataclasses.asdict(self) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index bac93c5..f09724f 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -8,6 +8,7 @@ import threading import time from typing import Dict, List, Optional, Sequence, Union +import hivemind import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file @@ -30,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, check_device_balance, convert_block +from petals.utils.ping import PingAggregator from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) @@ -64,7 +66,7 @@ class Server: compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None, - update_period: float = 150, + update_period: float = 60, expiration: Optional[float] = None, request_timeout: float = 3 * 60, session_timeout: float = 30 * 60, @@ -220,7 +222,7 @@ class Server: throughput=throughput, adapters=tuple(adapters), version=petals.__version__, - torch_dtype=str(torch_dtype).lstrip("torch."), + torch_dtype=str(torch_dtype).replace("torch.", ""), quant_type=quant_type.name.lower(), using_relay=self.dht.client_mode, ) @@ -413,8 +415,8 @@ class ModuleContainer(threading.Thread): module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout) - server_info.state = ServerState.JOINING - joining_announcer = ModuleAnnouncerThread( + assert server_info.state == ServerState.JOINING + dht_announcer = ModuleAnnouncerThread( module_uids, dht, server_info, @@ -424,7 +426,7 @@ class ModuleContainer(threading.Thread): expiration=expiration, daemon=True, ) - joining_announcer.start() + dht_announcer.start() logger.info(f"Announced that blocks {block_indices} are joining") assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices) @@ -476,6 +478,8 @@ class ModuleContainer(threading.Thread): max_batch_size=max_batch_size, ) + merge_inference_pools_inplace(blocks) + if should_validate_reachability: validate_reachability(dht.peer_id) except: @@ -483,29 +487,15 @@ class ModuleContainer(threading.Thread): for backend in blocks.values(): backend.shutdown() - joining_announcer.stop.set() - joining_announcer.join() - server_info.state = ServerState.OFFLINE - declare_active_modules( - dht, - module_uids, - server_info, - expiration_time=get_dht_time() + expiration, - ) + dht_announcer.announce(ServerState.OFFLINE) logger.info(f"Announced that blocks {module_uids} are offline") raise - else: - joining_announcer.stop.set() - joining_announcer.join() - - merge_inference_pools_inplace(blocks) return cls( dht, dht_prefix, blocks, - block_config=block_config, - memory_cache=memory_cache, + dht_announcer=dht_announcer, server_info=server_info, update_period=update_period, expiration=expiration, @@ -518,10 +508,9 @@ class ModuleContainer(threading.Thread): dht_prefix: str, module_backends: Dict[str, TransformerBackend], *, - block_config: PretrainedConfig, - memory_cache: MemoryCache, inference_max_length: int, num_handlers: int, + dht_announcer: ModuleAnnouncerThread, server_info: ServerInfo, update_period: float, expiration: Optional[float] = None, @@ -558,17 +547,8 @@ class ModuleContainer(threading.Thread): self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs) # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed. - self.server_info.state = ServerState.ONLINE - self.online_announcer = ModuleAnnouncerThread( - list(self.module_backends.keys()), - dht, - self.server_info, - block_config=block_config, - memory_cache=memory_cache, - update_period=update_period, - expiration=expiration, - daemon=True, - ) + dht_announcer.announce(ServerState.ONLINE) + self.dht_announcer = dht_announcer if start: self.run_in_background(await_ready=True) @@ -578,11 +558,6 @@ class ModuleContainer(threading.Thread): Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers, runs Runtime (self.runtime) to process incoming requests. """ - if not self.dht.is_alive(): - self.dht.run_in_background(await_ready=True) - - self.online_announcer.start() - for handler in self.conn_handlers: handler.run_in_background() @@ -621,16 +596,7 @@ class ModuleContainer(threading.Thread): Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes. If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL). """ - self.online_announcer.stop.set() - self.online_announcer.join() - - self.server_info.state = ServerState.OFFLINE - declare_active_modules( - self.dht, - self.module_backends.keys(), - self.server_info, - expiration_time=get_dht_time() + self.expiration, - ) + self.dht_announcer.announce(ServerState.OFFLINE) logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline") self.ready.clear() @@ -666,8 +632,10 @@ class ModuleAnnouncerThread(threading.Thread): *, block_config: PretrainedConfig, memory_cache: MemoryCache, - update_period: float = 30, + update_period: float, expiration: float, + max_pinged: int = 5, + max_reported: int = 10, **kwargs, ): super().__init__(**kwargs) @@ -678,20 +646,58 @@ class ModuleAnnouncerThread(threading.Thread): self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8 self.update_period = update_period self.expiration = expiration - self.stop = threading.Event() + self.trigger = threading.Event() + + self.max_pinged, self.max_reported = max_pinged, max_reported + last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1])) + dht_prefix, block_index = last_uid.split(UID_DELIMITER) + self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}" + self.ping_aggregator = PingAggregator(self.dht) def run(self) -> None: while True: + start_time = time.perf_counter() + self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token + if self.server_info.state != ServerState.OFFLINE: + self._ping_next_servers() + self.server_info.next_pings = { + peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items() + } + else: + self.server_info.next_pings = None # No need to ping if we're disconnecting + declare_active_modules( self.dht, self.module_uids, self.server_info, expiration_time=get_dht_time() + self.expiration, ) - if self.stop.wait(self.update_period): + if self.server_info.state == ServerState.OFFLINE: break + delay = self.update_period - (time.perf_counter() - start_time) + if delay < 0: + logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it") + self.trigger.wait(max(delay, 0)) + self.trigger.clear() + + def announce(self, state: ServerState) -> None: + self.server_info.state = state + self.trigger.set() + if state == ServerState.OFFLINE: + self.join() + + def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]: + [module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True) + if module_info is None: + return + + next_servers = list(module_info.servers) + if len(next_servers) > self.max_pinged: + next_servers = random.sample(next_servers, self.max_pinged) + self.ping_aggregator.ping(next_servers) + class RuntimeWithDeduplicatedPools(Runtime): """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool""" diff --git a/src/petals/utils/ping.py b/src/petals/utils/ping.py new file mode 100644 index 0000000..de675a5 --- /dev/null +++ b/src/petals/utils/ping.py @@ -0,0 +1,60 @@ +import asyncio +import math +import time +from functools import partial +from typing import Dict, Sequence + +import hivemind +from hivemind.proto import dht_pb2 +from hivemind.utils.logging import get_logger + +logger = get_logger(__name__) + + +async def ping( + peer_id: hivemind.PeerID, + _dht: hivemind.DHT, + node: hivemind.dht.DHTNode, + *, + wait_timeout: float = 1, +) -> float: + try: + ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info) + start_time = time.perf_counter() + await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout) + return time.perf_counter() - start_time + except Exception: + logger.debug(f"Failed to ping {peer_id}:", exc_info=True) + return math.inf + + +async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]: + rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids]) + return dict(zip(peer_ids, rpc_infos)) + + +class PingAggregator: + def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600): + self.dht = dht + self.ema_alpha = ema_alpha + self.expiration = expiration + self.ping_emas = hivemind.TimedStorage() + + def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs): + current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs)) + logger.debug(f"Current RTTs: {current_rtts}") + + expiration = hivemind.get_dht_time() + self.expiration + for peer_id, rtt in current_rtts.items(): + prev_rtt = self.ping_emas.get(peer_id) + if prev_rtt is not None and prev_rtt.value != math.inf: + rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing + self.ping_emas.store(peer_id, rtt, expiration) + + def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]: + with self.ping_emas.freeze(): + smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()} + logger.debug(f"Smothed RTTs: {smoothed_rtts}") + + fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers] + return dict(fastest_rtts) From 3f733a96e37d34d6b823f825e5f0f67ad3e916c3 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 16 Jul 2023 03:07:21 +0400 Subject: [PATCH 118/168] Use bitsandbytes 0.40.1.post1 (#357) --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 0053628..7fc930e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ packages = find: python_requires = >=3.7 install_requires = torch>=1.12 - bitsandbytes==0.40.0.post4 + bitsandbytes==0.40.1.post1 accelerate>=0.16.0,<0.21.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 From 9517dd1e3d2bc5ad42c9f9d8c535d714cda6175a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 17 Jul 2023 05:02:08 +0400 Subject: [PATCH 119/168] Update readme and "Getting started" link (#360) This updates readme with the latest updates and fixes an old Colab link, as pointed out in #359. --- README.md | 52 +++++++++++++++++++++++----------------------------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 68564ab..5c3a4c5 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,24 @@


- Run 100B+ language models at home, BitTorrent-style.
+ Run large language models at home, BitTorrent-style.
Fine-tuning and inference up to 10x faster than offloading


-Generate text using distributed [LLaMA-65B](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md), [BLOOM-176B](https://huggingface.co/bigscience/bloom) or [BLOOMZ-176B](https://huggingface.co/bigscience/bloomz) and fine-tune them for your own tasks: +Generate text with distributed [LLaMA-65B](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md), [Guanaco](https://huggingface.co/timdettmers/guanaco-65b), [BLOOM-176B](https://huggingface.co/bigscience/bloom), or [BLOOMZ](https://huggingface.co/bigscience/bloomz) and fine-tune them for your own tasks — right from your desktop computer or Google Colab: ```python +from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM -model = AutoDistributedModelForCausalLM.from_pretrained("bigscience/bloom", tuning_mode="ptune", pre_seq_len=16) -# Embeddings & prompts are on your device, BLOOM blocks are distributed across the Internet +model_name = "bigscience/bloom" # You can use any Hugging Face hub repo with a supported model +tokenizer = AutoTokenizer(model_name) +model = AutoDistributedModelForCausalLM.from_pretrained(model_name) +# Embeddings & prompts are on your device, transformer blocks are distributed across the Internet inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"] outputs = model.generate(inputs, max_new_tokens=5) print(tokenizer.decode(outputs[0])) # A cat sat on a mat... - -# Fine-tuning (updates only prompts or adapters hosted locally) -optimizer = torch.optim.AdamW(model.parameters()) -for input_ids, labels in data_loader: - outputs = model.forward(input_ids) - loss = cross_entropy(outputs.logits, labels) - optimizer.zero_grad() - loss.backward() - optimizer.step() ```

@@ -33,40 +27,42 @@ for input_ids, labels in data_loader: 🔏 Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. +📋 Make sure you follow the model's terms of use (see [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license) licenses). Note that LLaMA is available for non-commercial purposes only, and you have to file a request [here](https://bit.ly/llama-license) to use it in your own projects. + ### Connect your GPU and increase Petals capacity -Run our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windows with [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)): +Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+): ```bash -sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ - learningathome/petals:main python -m petals.cli.run_server bigscience/bloom --port 31330 +conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia +pip install git+https://github.com/bigscience-workshop/petals +python -m petals.cli.run_server bigscience/bloom ``` -Or run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+): +Or run our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windows with [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)): ```bash -conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -pip install -U petals -python -m petals.cli.run_server bigscience/bloom +sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ + learningathome/petals:main python -m petals.cli.run_server bigscience/bloom --port 31330 ``` -📚 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to configure the server to use multiple GPUs, address common issues, etc. - 🔒 Hosting a server does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). +📚 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to configure the server to use multiple GPUs, address common issues, etc. + 💬 If you have any issues or feedback, let us know on [our Discord server](https://discord.gg/D9MwApKgWa)! ### Check out tutorials, examples, and more Basic tutorials: -- Getting started: [tutorial](https://colab.research.google.com/drive/1Ervk6HPNS6AYVr3xVdQnY5a-TjjmLCdQ?usp=sharing) +- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing) +- Prompt-tune LLaMA-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) - Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) -- Prompt-tune BLOOM for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) Useful tools and advanced guides: -- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP endpoint): [source code](https://github.com/borzunov/chat.petals.ml) +- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/borzunov/chat.petals.ml) - [Monitor](http://health.petals.ml) for the public swarm: [source code](https://github.com/borzunov/health.petals.ml) - Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) @@ -76,12 +72,10 @@ Learning more: - Frequently asked questions: [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions) - In-depth system description: [paper](https://arxiv.org/abs/2209.01188) -📋 If you build an app running BLOOM with Petals, make sure it follows the BLOOM's [terms of use](https://huggingface.co/bigscience/bloom). - ## How does it work? -- Petals runs large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. -- Single-batch inference runs at ≈ 1 sec per step (token) — [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](http://chat.petals.ml) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. +- Petals runs large language models like [LLaMA-65B](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. +- Single-batch inference runs at 3-4 steps/sec for LLaMA-65B and ≈ 1 step/sec for BLOOM-176B — [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](http://chat.petals.ml) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.

From 11f0d992d7af89db33b5f7882a3e9fc5214cc4fc Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 17 Jul 2023 13:45:59 +0400 Subject: [PATCH 120/168] Report inference, forward, and network RPS separately (#358) Inference RPS may be very different from forward RPS. E.g., currently bnb uses a completely different algorithm for NF4 inference. We report detailed RPS info that can be then used for shortest-path routing for inference. --- src/petals/client/routing/sequence_info.py | 9 ++- src/petals/client/routing/sequence_manager.py | 2 +- src/petals/data_structures.py | 11 +++- src/petals/server/server.py | 6 +- src/petals/server/throughput.py | 66 +++++++++++++------ src/petals/utils/ping.py | 2 +- tests/test_aux_functions.py | 8 ++- 7 files changed, 72 insertions(+), 32 deletions(-) diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py index b35b02b..bce6712 100644 --- a/src/petals/client/routing/sequence_info.py +++ b/src/petals/client/routing/sequence_info.py @@ -73,12 +73,15 @@ class RemoteSequenceInfo: active_spans = {} for block_index, info in enumerate(block_infos): if info is not None: - for peer_id, server in info.servers.items(): - if server.state != ServerState.ONLINE: + for peer_id, server_info in info.servers.items(): + if server_info.state != ServerState.ONLINE: continue if peer_id not in active_spans: active_spans[peer_id] = RemoteSpanInfo( - peer_id=peer_id, start=block_index, end=block_index + 1, throughput=server.throughput + peer_id=peer_id, + start=block_index, + end=block_index + 1, + server_info=server_info, ) else: # peer_id in active_spans active_spans[peer_id].end = block_index + 1 diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index fc505cc..19b475b 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -151,7 +151,7 @@ class RemoteSequenceManager: raise MissingBlocksError(current_index) if mode == "max_throughput": - span_weights = np.array([span.throughput for span in candidate_spans], dtype=np.float64) + span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) elif mode == "min_latency": span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64) else: diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 8d7d50b..e3a3e03 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -19,10 +19,17 @@ class ServerState(Enum): ONLINE = 2 +RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) + + @pydantic.dataclasses.dataclass class ServerInfo: state: ServerState - throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) + throughput: RPS + + network_rps: Optional[RPS] = None + forward_rps: Optional[RPS] = None + inference_rps: Optional[RPS] = None adapters: Sequence[str] = () version: Optional[str] = None @@ -60,7 +67,7 @@ class RemoteSpanInfo: peer_id: PeerID start: int end: int - throughput: float + server_info: ServerInfo @property def length(self): diff --git a/src/petals/server/server.py b/src/petals/server/server.py index f09724f..aea57c7 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -206,7 +206,7 @@ class Server: assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: - throughput = get_server_throughput( + throughput_info = get_server_throughput( converted_model_name_or_path, self.block_config, device, @@ -217,14 +217,16 @@ class Server: force_eval=(throughput == "eval"), cache_dir=cache_dir, ) + else: + throughput_info = {"throughput": throughput} self.server_info = ServerInfo( state=ServerState.JOINING, - throughput=throughput, adapters=tuple(adapters), version=petals.__version__, torch_dtype=str(torch_dtype).replace("torch.", ""), quant_type=quant_type.name.lower(), using_relay=self.dht.client_mode, + **throughput_info, ) self.balance_quality = balance_quality diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 20625e6..d92355e 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -43,13 +43,13 @@ def get_server_throughput( tensor_parallel_devices: Sequence[torch.device], force_eval: bool = False, cache_dir: Optional[str] = None, -) -> float: +) -> Dict[str, float]: dtype = resolve_block_dtype(config, dtype) if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR lock_path = Path(cache_dir, "throughput.lock") - cache_path = Path(cache_dir, "throughput_v3.json") + cache_path = Path(cache_dir, "throughput_v4.json") # We use the system-wide lock since only one process at a time can measure the host throughput os.makedirs(lock_path.parent, exist_ok=True) @@ -93,10 +93,12 @@ def get_server_throughput( # Assuming the start block index is distributed uniformly, the average number of blocks used per request is # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2 average_blocks_used = (num_blocks + 1) / 2 - throughput = throughput_info["compute_rps"] / average_blocks_used + throughput = throughput_info["forward_rps"] / average_blocks_used throughput = min(throughput, throughput_info.get("network_rps", math.inf)) + throughput_info["throughput"] = throughput logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks") - return throughput + + return throughput_info def measure_throughput_info( @@ -114,15 +116,31 @@ def measure_throughput_info( ) throughput_info = { - "compute_rps": measure_compute_rps( - config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices - ) + "inference_rps": measure_compute_rps( + config, + device, + dtype, + quant_type=quant_type, + tensor_parallel_devices=tensor_parallel_devices, + n_tokens=1, + n_steps=100, + inference=True, + ), + "forward_rps": measure_compute_rps( + config, + device, + dtype, + quant_type=quant_type, + tensor_parallel_devices=tensor_parallel_devices, + n_tokens=1024, + n_steps=10, + inference=False, + ), } try: throughput_info["network_rps"] = measure_network_rps(config) except Exception as e: - logger.warning(f"Failed to measure network throughput: {repr(e)}") - logger.warning("Proceeding with the compute throughput only") + logger.info(f"Network throughput is not available: {e}") return throughput_info @@ -135,6 +153,8 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Opt process.terminate() raise RuntimeError(f"speedtest did not finish in {timeout} seconds") network_info = pipe_recv.recv() + if "exception" in network_info: + raise RuntimeError(f"speedtest failed: {network_info['exception']}") bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request @@ -150,12 +170,15 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Opt def _measure_bits_per_second(pipe_send: mp.Pipe): - s = speedtest.Speedtest() - s.get_servers() - s.get_best_server() - s.download() - s.upload() - pipe_send.send(s.results.dict()) + try: + s = speedtest.Speedtest() + s.get_servers() + s.get_best_server() + s.download() + s.upload() + pipe_send.send(s.results.dict()) + except Exception as e: + pipe_send.send({"exception": repr(e)}) def measure_compute_rps( @@ -165,8 +188,9 @@ def measure_compute_rps( *, quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], - n_tokens: int = 16, - n_steps: int = 500, + n_tokens: int, + n_steps: int, + inference: bool, ) -> float: if not tensor_parallel_devices: tensor_parallel_devices = (device,) @@ -180,7 +204,7 @@ def measure_compute_rps( dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype) start_time = time.perf_counter() - _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache) + _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None) if step >= 1: # Skip the 1st step to exclude the initialization time elapsed += time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed @@ -191,8 +215,8 @@ def measure_compute_rps( devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common()) logger.info( - f"Forward pass throughput: {device_rps:.1f} RPS per block " - f"({devices_repr}, {get_dtype_name(dtype, quant_type)})" + f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block " + f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})" ) return device_rps @@ -202,7 +226,7 @@ def get_device_name(device: torch.device) -> str: def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str: - name = str(dtype) + name = str(dtype).replace("torch.", "") if quant_type != QuantType.NONE: name += f", quantized to {quant_type.name.lower()}" return name diff --git a/src/petals/utils/ping.py b/src/petals/utils/ping.py index de675a5..d5fd129 100644 --- a/src/petals/utils/ping.py +++ b/src/petals/utils/ping.py @@ -16,7 +16,7 @@ async def ping( _dht: hivemind.DHT, node: hivemind.dht.DHTNode, *, - wait_timeout: float = 1, + wait_timeout: float = 5, ) -> float: try: ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info) diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index e5b450c..64c9c6a 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -24,8 +24,10 @@ def test_bnb_not_imported_when_unnecessary(): @pytest.mark.forked +@pytest.mark.parametrize("inference", [False, True]) +@pytest.mark.parametrize("n_tokens", [1, 16]) @pytest.mark.parametrize("tensor_parallel", [False, True]) -def test_compute_throughput(tensor_parallel: bool): +def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool): config = AutoDistributedConfig.from_pretrained(MODEL_NAME) tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else () compute_rps = measure_compute_rps( @@ -34,6 +36,8 @@ def test_compute_throughput(tensor_parallel: bool): dtype=torch.bfloat16, quant_type=QuantType.NONE, tensor_parallel_devices=tensor_parallel_devices, - n_steps=10, + n_tokens=n_tokens, + n_steps=5, + inference=inference, ) assert isinstance(compute_rps, float) and compute_rps > 0 From fd30f7ce103545a11b7d4d9c02131b6d1dab4a1a Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Tue, 18 Jul 2023 10:44:41 +0900 Subject: [PATCH 121/168] Fix typo in generation_algorithms.py (#364) --- src/petals/utils/generation_algorithms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/utils/generation_algorithms.py b/src/petals/utils/generation_algorithms.py index d58f073..d085e8b 100644 --- a/src/petals/utils/generation_algorithms.py +++ b/src/petals/utils/generation_algorithms.py @@ -16,7 +16,7 @@ class DecodingAlgorithm(ABC): @abstractmethod def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: """ - :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size) + :param logits: A tensor of shape (batch_size, seq_length, vocab_size) :return: A tuple of selected token ids and corresponding hypotheses. The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size) """ From 62d9ed5ce7b08f621d7947679087fd54e8df723b Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 18 Jul 2023 08:46:36 +0400 Subject: [PATCH 122/168] Implement shortest-path routing for inference (#362) This PR: 1. **Adds shortest path routing for inference.** We build a graph with client-server and server-server latencies and compute costs, as well as empirically measured overheads. For client-server latencies, we ping possible first and last servers in a sequence in `SequenceManager.update()`. We penalize servers who may not have enough cache for our request. This uses info added to DHT in #355, #356, #358. 2. **Makes a server ping neighboring servers in addition to next ones.** This is to get an opportunity to change the server even before we use all its blocks (e.g., because a neighboring server is faster). This feature is not enabled though, since it increases graph size for N servers to O(N^2) - but we may enable it if needed. 3. **Fixes a `SequenceManager` bug with the first `update()`.** Previously, this update was likely to produce incorrect information and cause to `MissingBlocksErrors` until the next update happens. --- setup.cfg | 1 + src/petals/__init__.py | 2 +- src/petals/cli/run_server.py | 2 +- src/petals/client/inference_session.py | 4 +- src/petals/client/routing/sequence_manager.py | 196 ++++++++++++++++-- src/petals/server/server.py | 31 +-- src/petals/utils/ping.py | 29 +-- src/petals/utils/random.py | 12 ++ 8 files changed, 222 insertions(+), 55 deletions(-) create mode 100644 src/petals/utils/random.py diff --git a/setup.cfg b/setup.cfg index 7fc930e..10f56b5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,6 +48,7 @@ install_requires = sentencepiece>=0.1.99 peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735 safetensors>=0.3.1 + Dijkstar>=2.6.0 [options.extras_require] dev = diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 3e67633..d02dbeb 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.2.0.dev2" +__version__ = "1.2.0.dev3" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index f2a0168..ce69974 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -84,7 +84,7 @@ def main(): parser.add_argument('--attn_cache_tokens', type=int, default=8192, help='The number of past attention key/value pairs that will be stored between inference steps. ' 'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).') - parser.add_argument('--alloc_timeout', type=float, default=60, + parser.add_argument('--alloc_timeout', type=float, default=5, help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed ' 'before rejecting the request') parser.add_argument('--revision', type=str, default=None, diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 8c2dfc9..0e5d6b4 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -340,7 +340,9 @@ class InferenceSession: f"from block {block_idx} to {update_end} will be regenerated" ) - updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency") + updated_spans = self._sequence_manager.make_sequence( + block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length + ) # make_sequence() could return a longer sequence updated_spans[-1].end = min(updated_spans[-1].end, update_end) updated_sessions = self._enter_server_sessions(updated_spans) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 19b475b..5b1ab3f 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -10,6 +10,7 @@ import time from typing import Any, Collection, Dict, List, Optional, Sequence, Union from weakref import WeakMethod +import dijkstar import numpy as np from hivemind import DHT, P2P, MSGPackSerializer, PeerID from hivemind.dht.node import Blacklist @@ -23,6 +24,8 @@ from petals.client.routing.spending_policy import NoSpendingPolicy from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState from petals.server.handler import TransformerConnectionHandler +from petals.utils.ping import PingAggregator +from petals.utils.random import sample_up_to logger = get_logger(__name__) @@ -33,6 +36,7 @@ class SequenceManagerConfig: dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name) daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers + show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True] allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers use_server_to_server: bool = True # Use direct server-to-server communication @@ -43,7 +47,10 @@ class SequenceManagerConfig: min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) max_backoff: float = 60 # limit maximal sleep time between retries to this value ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds - active_adapter: Optional[str] = None + active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo) + + max_pinged: int = 5 # max servers to ping from each sequence side, per update + ping_timeout: float = 2 # max time to wait for pings, per update @dataclasses.dataclass @@ -79,7 +86,6 @@ class RemoteSequenceManager: *, dht: Optional[DHT] = None, state: Optional[SequenceManagerState] = None, - active_adapter: Optional[str] = None, ): assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." @@ -94,7 +100,7 @@ class RemoteSequenceManager: dht = DHT( initial_peers=config.initial_peers, client_mode=True, - num_workers=config.num_hidden_layers, + num_workers=32, startup_timeout=config.daemon_startup_timeout, start=True, ) @@ -109,25 +115,25 @@ class RemoteSequenceManager: self._thread_start_lock = threading.Lock() self.policy = NoSpendingPolicy() + self.ping_aggregator = PingAggregator(dht) + if state.banned_peers is None: state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0) if state.sequence_info is None: state.sequence_info = RemoteSequenceInfo.make_empty(block_uids) - if state.sequence_info.last_updated_time is None: - # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records - # in the first _update() instead of the latest ones. This makes the first .update() faster. - petals.dht_utils.get_remote_module_infos( - self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True - ) - self._need_latest_infos = False - else: + if state.sequence_info.last_updated_time is not None: assert block_uids == state.sequence_info.block_uids self._thread.ready.set() # no need to await the first dht fetch self._need_latest_infos = True def make_sequence( - self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str + self, + start_index: int = 0, + end_index: Optional[int] = None, + *, + mode: str, + cache_tokens_needed: Optional[int] = None, ) -> List[RemoteSpanInfo]: """ Form a sequence of remote servers that collectively serve all consecutive layers @@ -143,6 +149,150 @@ class RemoteSequenceManager: self.update(wait=True) # this will await an existing update or trigger a new one (if not updating) end_index = end_index if end_index is not None else len(self) + + if mode == "min_latency": + span_sequence = self._make_sequence_with_min_latency( + start_index, end_index, cache_tokens_needed=cache_tokens_needed + ) + elif mode == "max_throughput": + span_sequence = self._make_sequence_with_max_throughput(start_index, end_index) + else: + raise RuntimeError(f"Unexpected mode {mode}") + + if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"): + route_repr = " => ".join( + [f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence] + ) + logger.info(f"Route found: {route_repr}") + return span_sequence + + def _make_sequence_with_min_latency( + self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int] + ) -> List[RemoteSpanInfo]: + if start_index == end_index: + return [] + + with self.lock_changes: + missing_blocks = [ + block_idx + for block_idx in range(start_index, end_index) + if not self.state.sequence_info.spans_containing_block[block_idx] + ] + if missing_blocks: + raise MissingBlocksError(missing_blocks) + server_infos = { + span.peer_id: span.server_info + for block_idx in range(start_index, end_index) + for span in self.state.sequence_info.spans_containing_block[block_idx] + } + + graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed) + + path = dijkstar.find_path(graph, "start", "end") + logger.debug(f"Path info: {path}") + if start_index == 0 and end_index == len(self): + logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec") + + span_sequence = [] + for peer_id, block_idx in path.nodes[1:-1]: + if not span_sequence or span_sequence[-1].peer_id != peer_id: + span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id])) + else: + span_sequence[-1].end = block_idx + + # Remove empty spans that can appear if we don't force to go to the end of each server and network delay + # don't follow triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors + span_sequence = [span for span in span_sequence if span.length > 0] + + return span_sequence + + def _build_inference_graph( + self, + start_index: int, + end_index: int, + *, + cache_tokens_needed: Optional[int], + overhead_coeff: float = 1.82, # Backend overhead (empirically measured) + overhead_delay: float = 0.018, # Serialization overhead (empirically measured) + default_inference_rps: float = 300, # If inference RPS unknown + alloc_delay: float = 10, # If not enough cache left, we penalize the edge + ) -> dijkstar.Graph: + missing_blocks = [ + block_idx + for block_idx in range(start_index, end_index) + if not self.state.sequence_info.spans_containing_block[block_idx] + ] + if missing_blocks: + raise MissingBlocksError(missing_blocks) + + client_server_rtts = self.ping_aggregator.to_dict() + + graph = dijkstar.Graph() + + # Clent -> server network delays + for span in self.state.sequence_info.spans_containing_block[start_index]: + delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id)) + delay += overhead_delay + if not self._has_cache_for(span, cache_tokens_needed): + delay += alloc_delay + graph.add_edge("start", (span.peer_id, start_index), delay) + + # Server -> client network delays + for span in self.state.sequence_info.spans_containing_block[end_index - 1]: + delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id)) + graph.add_edge((span.peer_id, end_index), "end", delay) + + # Server -> server network delays + for block_idx in range(start_index + 1, end_index): + for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]: + if cur_span.end != block_idx: + # If we choose a server, we force to go to the end of it before switching to a new one + # to avoid O(N^2) graphs for N servers + continue + + for next_span in self.state.sequence_info.spans_containing_block[block_idx]: + rtt = None + if cur_span.server_info.next_pings is not None: + rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58()) + delay = self._rtt_to_delay(rtt) + delay += overhead_delay + if not self._has_cache_for(next_span, cache_tokens_needed): + delay += alloc_delay + graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay) + + # Compute delays + for span in self.state.sequence_info.spans_by_priority: + for block_idx in range(max(span.start, start_index), min(span.end, end_index)): + inference_rps = span.server_info.inference_rps + if inference_rps is None: + inference_rps = default_inference_rps + graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), overhead_coeff / inference_rps) + + return graph + + @staticmethod + def _rtt_to_delay( + rtt: float, + *, + default_delay: float = 0.15, # If network delay unknown + max_delay: float = 5, # If unreachable, we don't want to discard the edge completely + ) -> float: + if rtt is None: + return default_delay + return min(rtt / 2, max_delay) + + @staticmethod + def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool: + if cache_tokens_needed is None or span.server_info.cache_tokens_left is None: + return True + + # Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through + # this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage, + # so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate. + # This is okay since false positives are more costly than false negatives here. + return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left + + def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]: span_sequence = [] current_index = start_index while current_index < end_index: @@ -150,20 +300,12 @@ class RemoteSequenceManager: if not candidate_spans: raise MissingBlocksError(current_index) - if mode == "max_throughput": - span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) - elif mode == "min_latency": - span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64) - else: - raise RuntimeError(f"Unexpected mode {mode}") + span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end span_sequence.append(dataclasses.replace(chosen_span, start=current_index)) current_index = chosen_span.end - - route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]) - logger.debug(f"Route found: {route_repr}") return span_sequence def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager: @@ -182,10 +324,10 @@ class RemoteSequenceManager: def _update(self): """Perform an immediate and synchronous refresh, may take time""" + new_block_infos = petals.dht_utils.get_remote_module_infos( - self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos + self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True ) - self._need_latest_infos = True # All future _update() should use latest infos for block_info in new_block_infos: if not block_info: @@ -217,6 +359,14 @@ class RemoteSequenceManager: with self.lock_changes: self.state.sequence_info.update_(new_block_infos) + + first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]] + last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]] + + pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged)) + pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged)) + self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout) + self.ready.set() def on_request_failure(self, peer_id: Optional[PeerID]): diff --git a/src/petals/server/server.py b/src/petals/server/server.py index aea57c7..0793dff 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -32,6 +32,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, check_device_balance, convert_block from petals.utils.ping import PingAggregator +from petals.utils.random import sample_up_to from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) @@ -61,7 +62,7 @@ class Server: cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, attn_cache_tokens: int = 8192, - alloc_timeout: float = 60, + alloc_timeout: float = 5, device: Optional[Union[str, torch.device]] = None, compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, @@ -637,7 +638,6 @@ class ModuleAnnouncerThread(threading.Thread): update_period: float, expiration: float, max_pinged: int = 5, - max_reported: int = 10, **kwargs, ): super().__init__(**kwargs) @@ -650,10 +650,11 @@ class ModuleAnnouncerThread(threading.Thread): self.expiration = expiration self.trigger = threading.Event() - self.max_pinged, self.max_reported = max_pinged, max_reported - last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1])) - dht_prefix, block_index = last_uid.split(UID_DELIMITER) - self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}" + self.max_pinged = max_pinged + dht_prefix = module_uids[0].split(UID_DELIMITER)[0] + block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids] + start_block, end_block = min(block_indices), max(block_indices) + 1 + self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)] self.ping_aggregator = PingAggregator(self.dht) def run(self) -> None: @@ -664,7 +665,7 @@ class ModuleAnnouncerThread(threading.Thread): if self.server_info.state != ServerState.OFFLINE: self._ping_next_servers() self.server_info.next_pings = { - peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items() + peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items() } else: self.server_info.next_pings = None # No need to ping if we're disconnecting @@ -691,14 +692,14 @@ class ModuleAnnouncerThread(threading.Thread): self.join() def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]: - [module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True) - if module_info is None: - return - - next_servers = list(module_info.servers) - if len(next_servers) > self.max_pinged: - next_servers = random.sample(next_servers, self.max_pinged) - self.ping_aggregator.ping(next_servers) + module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True) + middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers} + pinged_servers = set(sample_up_to(middle_servers, self.max_pinged)) + pinged_servers.discard(self.dht.peer_id) + if module_infos[-1] is not None: + # Sample servers hosting the block after the last one (most likely continuations) separately + pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) + self.ping_aggregator.ping(list(pinged_servers)) class RuntimeWithDeduplicatedPools(Runtime): diff --git a/src/petals/utils/ping.py b/src/petals/utils/ping.py index d5fd129..4245bf4 100644 --- a/src/petals/utils/ping.py +++ b/src/petals/utils/ping.py @@ -1,5 +1,6 @@ import asyncio import math +import threading import time from functools import partial from typing import Dict, Sequence @@ -34,27 +35,27 @@ async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> class PingAggregator: - def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600): + def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300): self.dht = dht self.ema_alpha = ema_alpha self.expiration = expiration self.ping_emas = hivemind.TimedStorage() + self.lock = threading.Lock() - def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs): + def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None: current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs)) logger.debug(f"Current RTTs: {current_rtts}") - expiration = hivemind.get_dht_time() + self.expiration - for peer_id, rtt in current_rtts.items(): - prev_rtt = self.ping_emas.get(peer_id) - if prev_rtt is not None and prev_rtt.value != math.inf: - rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing - self.ping_emas.store(peer_id, rtt, expiration) + with self.lock: + expiration = hivemind.get_dht_time() + self.expiration + for peer_id, rtt in current_rtts.items(): + prev_rtt = self.ping_emas.get(peer_id) + if prev_rtt is not None and prev_rtt.value != math.inf: + rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing + self.ping_emas.store(peer_id, rtt, expiration) - def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]: - with self.ping_emas.freeze(): + def to_dict(self) -> Dict[hivemind.PeerID, float]: + with self.lock, self.ping_emas.freeze(): smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()} - logger.debug(f"Smothed RTTs: {smoothed_rtts}") - - fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers] - return dict(fastest_rtts) + logger.debug(f"Smothed RTTs: {smoothed_rtts}") + return smoothed_rtts diff --git a/src/petals/utils/random.py b/src/petals/utils/random.py new file mode 100644 index 0000000..15635ff --- /dev/null +++ b/src/petals/utils/random.py @@ -0,0 +1,12 @@ +import random +from typing import Collection, TypeVar + +T = TypeVar("T") + + +def sample_up_to(population: Collection[T], k: int) -> T: + if not isinstance(population, list): + population = list(population) + if len(population) > k: + population = random.sample(population, k) + return population From 3b300c32e4027f4e4916d3d527c34bb8260be853 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 18 Jul 2023 19:57:39 +0400 Subject: [PATCH 123/168] Update readme to show new models (#365) --- README.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 5c3a4c5..a30bcc2 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Generate text with distributed [LLaMA-65B](https://github.com/facebookresearch/l from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM -model_name = "bigscience/bloom" # You can use any Hugging Face hub repo with a supported model +model_name = "enoch/llama-65b-hf" # You can also use "bigscience/bloom" or "bigscience/bloomz" tokenizer = AutoTokenizer(model_name) model = AutoDistributedModelForCausalLM.from_pretrained(model_name) # Embeddings & prompts are on your device, transformer blocks are distributed across the Internet @@ -25,10 +25,10 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... 🚀  Try now in Colab

-🔏 Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. - 📋 Make sure you follow the model's terms of use (see [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license) licenses). Note that LLaMA is available for non-commercial purposes only, and you have to file a request [here](https://bit.ly/llama-license) to use it in your own projects. +🔏 Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. + ### Connect your GPU and increase Petals capacity Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+): @@ -36,21 +36,21 @@ Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linu ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install git+https://github.com/bigscience-workshop/petals -python -m petals.cli.run_server bigscience/bloom +python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b ``` Or run our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windows with [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)): ```bash -sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ - learningathome/petals:main python -m petals.cli.run_server bigscience/bloom --port 31330 +sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \ + python -m petals.cli.run_server --port 31330 enoch/llama-65b-hf --adapters timdettmers/guanaco-65b ``` -🔒 Hosting a server does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). +This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. -📚 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to configure the server to use multiple GPUs, address common issues, etc. +🔒 Hosting a server does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). -💬 If you have any issues or feedback, let us know on [our Discord server](https://discord.gg/D9MwApKgWa)! +💬 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! ### Check out tutorials, examples, and more @@ -94,7 +94,7 @@ Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/d ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -pip install -U petals +pip install git+https://github.com/bigscience-workshop/petals ``` If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes). From f97582fb5fc23673452dd4702cb192bf94809f88 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 02:35:47 +0400 Subject: [PATCH 124/168] Require transformers < 4.31.0 until we're compatible (#369) --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 10f56b5..27be9ac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,7 @@ install_requires = accelerate>=0.16.0,<0.21.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 - transformers>=4.30.1,<5.0.0 + transformers>=4.30.1,<4.31.0 speedtest-cli==2.1.3 pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind==1.1.8 hivemind==1.1.8 From a6fdfc0556ed9590a9efed1a5c211b3f001d8167 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 03:22:19 +0400 Subject: [PATCH 125/168] Fix AssertionError on rebalancing (#370) --- src/petals/server/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 0793dff..2a7904f 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -418,7 +418,7 @@ class ModuleContainer(threading.Thread): module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout) - assert server_info.state == ServerState.JOINING + server_info.state = ServerState.JOINING dht_announcer = ModuleAnnouncerThread( module_uids, dht, From 1ab35c282679043efc77423e6ca1c2418b429ac2 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 19 Jul 2023 02:22:40 +0300 Subject: [PATCH 126/168] Typo in inference_session.py --- src/petals/client/inference_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 0e5d6b4..5e14d8a 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -96,7 +96,7 @@ class _ServerInferenceSession: step_id: str, ) -> torch.Tensor: """ - Inference step: send a chunk of input tesors and receive a chunk of outputs + Inference step: send a chunk of input tensors and receive a chunk of outputs :prompts: optional DEEP prompts, added to a prefix of each layer's outputs, if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size] """ From c735dd7ba3d5d0115c0a80c6bc04163aabd689ea Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 05:15:30 +0400 Subject: [PATCH 127/168] Update transformers to 4.31.0 and peft to 0.4.0 (#371) --- .github/workflows/run-tests.yaml | 2 +- README.md | 2 +- setup.cfg | 8 ++++---- src/petals/__init__.py | 4 ++-- src/petals/cli/run_server.py | 2 +- src/petals/models/bloom/model.py | 16 ++++------------ src/petals/models/llama/model.py | 9 ++++----- src/petals/server/from_pretrained.py | 20 ++++++++++---------- src/petals/server/server.py | 16 ++++++++-------- src/petals/utils/peft.py | 23 +++++++++++++++-------- 10 files changed, 50 insertions(+), 52 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index b98667e..a81592b 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.7', '3.8', '3.9', '3.10' ] + python-version: [ '3.8', '3.9', '3.10' ] fail-fast: false timeout-minutes: 15 steps: diff --git a/README.md b/README.md index a30bcc2..3955835 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... ### Connect your GPU and increase Petals capacity -Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+): +Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+): ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia diff --git a/setup.cfg b/setup.cfg index 27be9ac..1e976a6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,9 +15,9 @@ classifiers = Intended Audience :: Science/Research License :: OSI Approved :: MIT License Programming Language :: Python :: 3 - Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 Topic :: Scientific/Engineering Topic :: Scientific/Engineering :: Mathematics Topic :: Scientific/Engineering :: Artificial Intelligence @@ -29,14 +29,14 @@ classifiers = package_dir = = src packages = find: -python_requires = >=3.7 +python_requires = >=3.8 install_requires = torch>=1.12 bitsandbytes==0.40.1.post1 accelerate>=0.16.0,<0.21.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 - transformers>=4.30.1,<4.31.0 + transformers>=4.31.0,<5.0.0 speedtest-cli==2.1.3 pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind==1.1.8 hivemind==1.1.8 @@ -46,7 +46,7 @@ install_requires = cpufeature>=0.2.0 packaging>=20.9 sentencepiece>=0.1.99 - peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735 + peft>=0.4.0 safetensors>=0.3.1 Dijkstar>=2.6.0 diff --git a/src/petals/__init__.py b/src/petals/__init__.py index d02dbeb..b72776c 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -16,8 +16,8 @@ __version__ = "1.2.0.dev3" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): assert ( - version.parse("4.30.1") <= version.parse(transformers.__version__) < version.parse("5.0.0") - ), "Please install a proper transformers version: pip install transformers>=4.30.1,<5.0.0" + version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0") + ), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0" def _override_bfloat16_mode_default(): diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index ce69974..c7264b4 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -132,7 +132,7 @@ def main(): parser.add_argument("--mean_balance_check_period", type=float, default=60, help="Check the swarm's balance every N seconds (and rebalance it if necessary)") - parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained") + parser.add_argument("--token", action='store_true', help="Hugging Face hub auth token for .from_pretrained()") parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType], help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or " "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. " diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index 7644148..e03adca 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -20,9 +20,7 @@ logger = get_logger(__name__) class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): """BloomModel, but all transformer layers are hosted by the swarm""" - _keys_to_ignore_on_load_missing = ( - BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing - ) + _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = [r"^h\."] config_class = DistributedBloomConfig @@ -93,11 +91,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM): - _keys_to_ignore_on_load_missing = ( - BloomForCausalLM._keys_to_ignore_on_load_missing - + DistributedBloomModel._keys_to_ignore_on_load_missing - + [r"^lm_head\."] # Missing since they are shared with input embeddings - ) + _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected config_class = DistributedBloomConfig @@ -115,10 +110,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification): - _keys_to_ignore_on_load_missing = ( - BloomForSequenceClassification._keys_to_ignore_on_load_missing - + DistributedBloomModel._keys_to_ignore_on_load_missing - ) + _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected config_class = DistributedBloomConfig diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py index 244207b..cafb45b 100644 --- a/src/petals/models/llama/model.py +++ b/src/petals/models/llama/model.py @@ -21,7 +21,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): """LlamaModel, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing - _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."] + _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."] config_class = DistributedLlamaConfig @@ -115,6 +115,8 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Ll def __init__(self, config: DistributedLlamaConfig): LlamaPreTrainedModel.__init__(self, config) self.model = DistributedLlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size self.lm_head = LMHead(config) # Initialize weights and apply final processing @@ -129,10 +131,7 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Ll class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification): - _keys_to_ignore_on_load_missing = ( - LlamaForSequenceClassification._keys_to_ignore_on_load_missing - + DistributedLlamaModel._keys_to_ignore_on_load_missing - ) + _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected config_class = DistributedLlamaConfig diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 41fb989..9898759 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -34,12 +34,12 @@ def load_pretrained_block( config: Optional[PretrainedConfig] = None, torch_dtype: Union[torch.dtype, str] = "auto", revision: Optional[str] = None, - use_auth_token: Optional[str] = None, + token: Optional[str] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, ) -> nn.Module: if config is None: - config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token) + config = AutoDistributedConfig.from_pretrained(model_name, token=token) if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR @@ -54,7 +54,7 @@ def load_pretrained_block( model_name, block_prefix, revision=revision, - use_auth_token=use_auth_token, + token=token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) @@ -82,12 +82,12 @@ def _load_state_dict_from_repo( block_prefix: str, *, revision: Optional[str] = None, - use_auth_token: Optional[str] = None, + token: Optional[str] = None, cache_dir: str, max_disk_space: Optional[int] = None, ) -> StateDict: index_file = get_file_from_repo( - model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir + model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir ) if index_file is not None: # Sharded model with open(index_file) as f: @@ -107,7 +107,7 @@ def _load_state_dict_from_repo( model_name, filename, revision=revision, - use_auth_token=use_auth_token, + token=token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) @@ -125,7 +125,7 @@ def _load_state_dict_from_file( filename: str, *, revision: Optional[str] = None, - use_auth_token: Optional[str] = None, + token: Optional[str] = None, cache_dir: str, max_disk_space: Optional[int] = None, delay: float = 30, @@ -137,7 +137,7 @@ def _load_state_dict_from_file( model_name, filename, revision=revision, - use_auth_token=use_auth_token, + use_auth_token=token, cache_dir=cache_dir, local_files_only=True, ) @@ -151,7 +151,7 @@ def _load_state_dict_from_file( try: with allow_cache_writes(cache_dir): url = hf_hub_url(model_name, filename, revision=revision) - file_size = get_hf_file_metadata(url, token=use_auth_token).size + file_size = get_hf_file_metadata(url, token=token).size if file_size is not None: free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) else: @@ -161,7 +161,7 @@ def _load_state_dict_from_file( model_name, filename, revision=revision, - use_auth_token=use_auth_token, + use_auth_token=token, cache_dir=cache_dir, local_files_only=False, ) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 2a7904f..d061d0a 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -77,7 +77,7 @@ class Server: balance_quality: float = 0.75, mean_balance_check_period: float = 120, mean_block_selection_delay: float = 2.5, - use_auth_token: Optional[str] = None, + token: Optional[str] = None, quant_type: Optional[QuantType] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, @@ -98,14 +98,14 @@ class Server: self.compression = compression self.stats_report_interval, self.update_period = stats_report_interval, update_period self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads - self.revision, self.use_auth_token = revision, use_auth_token + self.revision, self.token = revision, token if custom_module_path is not None: add_custom_models_from_file(custom_module_path) self.block_config = AutoDistributedConfig.from_pretrained( converted_model_name_or_path, - use_auth_token=use_auth_token, + token=token, revision=revision, ) @@ -271,7 +271,7 @@ class Server: self.block_config, self.torch_dtype, self.adapters, - use_auth_token=self.use_auth_token, + token=self.token, cache_dir=self.cache_dir, max_disk_space=self.max_disk_space, ) @@ -316,7 +316,7 @@ class Server: prefetch_batches=self.prefetch_batches, sender_threads=self.sender_threads, revision=self.revision, - use_auth_token=self.use_auth_token, + token=self.token, quant_type=self.quant_type, tensor_parallel_devices=self.tensor_parallel_devices, should_validate_reachability=self.should_validate_reachability, @@ -409,7 +409,7 @@ class ModuleContainer(threading.Thread): update_period: float, expiration: Optional[float], revision: Optional[str], - use_auth_token: Optional[str], + token: Optional[str], quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], should_validate_reachability: bool, @@ -443,7 +443,7 @@ class ModuleContainer(threading.Thread): config=block_config, torch_dtype=torch_dtype, revision=revision, - use_auth_token=use_auth_token, + token=token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) @@ -456,7 +456,7 @@ class ModuleContainer(threading.Thread): quant_type, adapters=server_info.adapters, freeze=True, - use_auth_token=use_auth_token, + token=token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index b182181..bbad779 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -45,13 +45,20 @@ def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", d return tensors -def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs): - config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs) +def get_adapter_from_repo( + repo_id: str, + block_idx: Optional[int] = None, + device: Optional[int] = None, + *, + token: Optional[str] = None, + **kwargs, +): + config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs) if config_path is None: raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}") config = PeftConfig.from_json_file(config_path) - weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs) + weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs) if weight_path is None: raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}") if block_idx is None: @@ -65,7 +72,7 @@ def load_peft( device: Optional[int] = None, *, revision: Optional[str] = None, - use_auth_token: Optional[str] = None, + token: Optional[str] = None, cache_dir: str, max_disk_space: Optional[int] = None, delay: float = 30, @@ -82,7 +89,7 @@ def load_peft( block_idx, device, revision=revision, - use_auth_token=use_auth_token, + token=token, cache_dir=cache_dir, local_files_only=False, ) @@ -93,9 +100,9 @@ def load_peft( try: with allow_cache_writes(cache_dir): config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision) - config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size + config_file_size = get_hf_file_metadata(config_url, token=token).size weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision) - weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size + weight_file_size = get_hf_file_metadata(weight_url, token=token).size file_size = config_file_size + weight_file_size if file_size is not None: @@ -108,7 +115,7 @@ def load_peft( block_idx, device, revision=revision, - use_auth_token=use_auth_token, + token=token, cache_dir=cache_dir, local_files_only=False, ) From 895327a0aec44c59456e5de0f4ce1a4a4cc8f97a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 12:45:14 +0400 Subject: [PATCH 128/168] Fix readme code example, require Python < 3.11 until supported (#374) * Fix readme code example * Require Python < 3.11 until it's supported --- README.md | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3955835..35d5cba 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM model_name = "enoch/llama-65b-hf" # You can also use "bigscience/bloom" or "bigscience/bloomz" -tokenizer = AutoTokenizer(model_name) +tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoDistributedModelForCausalLM.from_pretrained(model_name) # Embeddings & prompts are on your device, transformer blocks are distributed across the Internet diff --git a/setup.cfg b/setup.cfg index 1e976a6..7341684 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,7 @@ classifiers = package_dir = = src packages = find: -python_requires = >=3.8 +python_requires = >=3.8,<3.11 install_requires = torch>=1.12 bitsandbytes==0.40.1.post1 From 5a8de2f1f8173bce927381eb575d83e91dc90315 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 19 Jul 2023 12:31:47 +0300 Subject: [PATCH 129/168] Fix handler memory leak, get rid of mp.Manager (#373) This PR removes the memory leak from somewhere within handler.py that has something to do with mp.SyncManager. --- src/petals/server/handler.py | 189 ++++++++++++++++++++++------------- src/petals/server/server.py | 11 +- 2 files changed, 121 insertions(+), 79 deletions(-) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index d0531de..5d0a3d4 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -2,9 +2,9 @@ from __future__ import annotations import asyncio import contextlib -import multiprocessing.managers +import multiprocessing as mp import sys -from concurrent.futures import ThreadPoolExecutor +from enum import Enum from itertools import chain from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union @@ -42,20 +42,15 @@ logger = get_logger(__name__) # Fix pickling protobufs, see https://stackoverflow.com/a/74873028 sys.modules["runtime_pb2"] = runtime_pb2 -# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256 -_OriginalAutoProxy = multiprocessing.managers.AutoProxy - - -def patched_autoproxy(*args, manager_owned=True, **kwargs): - # Calling original AutoProxy without the unwanted key argument - return _OriginalAutoProxy(*args, **kwargs) - - -multiprocessing.managers.AutoProxy = patched_autoproxy +CACHE_TOKENS_AVAILABLE = "cache_tokens_available" -CACHE_TOKENS_AVAILABLE = "cache_tokens_available" +class Event(Enum): + NEW_SESSION = 0 + END_SESSION = 1 + PUSH = 2 + SHUTDOWN = 3 class TransformerConnectionHandler(ConnectionHandler): @@ -70,8 +65,8 @@ class TransformerConnectionHandler(ConnectionHandler): *, adapters: Optional[Sequence[str]], dht_prefix: str, - push_manager: multiprocessing.managers.SyncManager, - session_queues: Dict[str, multiprocessing.managers.BaseProxy], # BaseProxy for queue.Queue + handler_event_queues: Sequence[mp.Queue], + handler_index: int, inference_max_length: int, request_timeout: float, session_timeout: float, @@ -83,18 +78,28 @@ class TransformerConnectionHandler(ConnectionHandler): assert isinstance(module_backend, TransformerBackend) self.dht_prefix = dht_prefix self.adapters = adapters - self._push_manager = push_manager - self._session_queues = session_queues - self._executor = ThreadPoolExecutor(max_workers=float("inf")) # For waiting on self.session_queues + self._handler_event_queues = handler_event_queues + self._handler_index = handler_index + self._own_event_queue = handler_event_queues[handler_index] + self._listener_task: Optional[asyncio.Task] = None + self._session_queues: Dict[str, asyncio.Queue] = {} + self._session_handlers: Dict[str, int] = {} self.inference_max_length = inference_max_length self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout self._prioritizer = task_prioritizer + async def add_p2p_handlers(self, *args, **kwargs) -> None: + if self._listener_task is None: + # Start listening to our own event queue before we accept any requests + self._listener_task = asyncio.create_task(self._listen_to_event_queue()) + await super().add_p2p_handlers(*args, **kwargs) + def shutdown(self): if self.is_alive(): self._outer_pipe.send("_shutdown") + self._own_event_queue.put((Event.SHUTDOWN, None, None)) self.join(self.shutdown_timeout) if self.is_alive(): logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM") @@ -129,7 +134,6 @@ class TransformerConnectionHandler(ConnectionHandler): context: P2PContext, ) -> AsyncIterator[runtime_pb2.ExpertResponse]: """Compute a single step of inference using attention cache; update attention cache accordingly.""" - async with timeout(self.session_timeout): try: request = await asyncio.wait_for(anext(requests), self.step_timeout) @@ -146,7 +150,6 @@ class TransformerConnectionHandler(ConnectionHandler): active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") - if not requested_uids: raise ValueError("User must specify at least one block for inference, but got none") assert isinstance( @@ -235,6 +238,56 @@ class TransformerConnectionHandler(ConnectionHandler): finally: self._log_request("rpc_inference.close", requested_uids, context) + @contextlib.contextmanager + def _managed_session(self, session_id: str): + assert session_id not in self._session_queues, f"session id {session_id} is not unique" + try: + self._session_queues[session_id] = asyncio.Queue() + self._session_handlers[session_id] = self._handler_index + for other_index, other_queue in enumerate(self._handler_event_queues): + if other_index != self._handler_index: + other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index)) + yield + finally: + self._session_queues.pop(session_id).put_nowait(None) # put None so that the get task will not hang + del self._session_handlers[session_id] + for other_index, other_queue in enumerate(self._handler_event_queues): + if other_index != self._handler_index: + other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index)) + + def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest): + handler_index = self._session_handlers.get(session_id) + if handler_index is None: + logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}") + elif handler_index == self._handler_index: + self._session_queues[session_id].put_nowait(request) + else: + self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request)) + + async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]: + assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler" + return await self._session_queues[session_id].get() + + async def _listen_to_event_queue(self): + loop = asyncio.get_event_loop() + while True: + try: + event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get) + if event == Event.SHUTDOWN: + break + elif event == Event.NEW_SESSION: + self._session_handlers[session_id] = payload # index of the handler that owns that session + elif event == Event.END_SESSION: + self._session_handlers.pop(session_id, None) + elif event == Event.PUSH: + maybe_session_queue = self._session_queues.get(session_id) + if maybe_session_queue is not None: + maybe_session_queue.put_nowait(payload) + else: + raise RuntimeError(f"Unexpected event: {event}") + except Exception as e: + logger.exception(e) + async def _iterate_inference_steps( self, first_request: runtime_pb2.ExpertRequest, @@ -243,67 +296,60 @@ class TransformerConnectionHandler(ConnectionHandler): requested_uids: Sequence[str], context: P2PContext, ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]: - loop = asyncio.get_event_loop() - if session_id is not None: - push_queue = self._push_manager.Queue() - self._session_queues[session_id] = push_queue - processed_step_ids = set() n_pushes = n_late_pushes = 0 request = first_request anext_task = get_push_task = None try: - while request.tensors: # iterate while user is willing to supply tensors - metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} - step_id = metadata.get("step_id") - - pushed = metadata.get("pushed") - if pushed: - n_pushes += 1 - - if step_id is None or step_id not in processed_step_ids: - yield request, metadata - if step_id is not None: - processed_step_ids.add(step_id) - elif pushed: - n_late_pushes += 1 - self._log_request( - "rpc_inference.push", - requested_uids, - context, - warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time", + with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext(): + while request.tensors: # iterate while user is willing to supply tensors + metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + step_id = metadata.get("step_id") + + pushed = metadata.get("pushed") + if pushed: + n_pushes += 1 + self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push") + + if step_id is None or step_id not in processed_step_ids: + yield request, metadata + if step_id is not None: + processed_step_ids.add(step_id) + elif pushed: + n_late_pushes += 1 + self._log_request( + "rpc_inference.push", + requested_uids, + context, + warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time", + ) + + # Wait for the next request, coming either from the `requests` iterator or `push_queue` + if anext_task is None: + anext_task = asyncio.create_task(anext(requests)) + if get_push_task is None: + if session_id is not None: + get_push_task = asyncio.create_task(self._get_from_session_queue(session_id)) + else: + get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task + done, _ = await asyncio.wait( + [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED ) - # Wait for the next request, coming either from the `requests` iterator or `push_queue` - if anext_task is None: - anext_task = asyncio.create_task(anext(requests)) - if get_push_task is None: - if session_id is not None: - get_push_task = loop.run_in_executor(self._executor, push_queue.get) + if anext_task in done: + request = await anext_task + anext_task = None + elif get_push_task in done: + request = await get_push_task + get_push_task = None else: - get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task - done, _ = await asyncio.wait( - [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED - ) - - if anext_task in done: - request = await anext_task - anext_task = None - elif get_push_task in done: - request = await get_push_task - get_push_task = None - else: - self._log_request("rpc_inference.step", requested_uids, context, warning="timed out") - anext_task.cancel() - get_push_task.cancel() - return + self._log_request("rpc_inference.step", requested_uids, context, warning="timed out") + anext_task.cancel() + get_push_task.cancel() + return except: logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True) raise - finally: - if session_id is not None: - push_queue.put(None) # Stop thread for get_push_task - del self._session_queues[session_id] async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: """Directly push activation tensors from one server to another""" @@ -312,8 +358,7 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) session_id = metadata["session_id"] self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}") - - self._session_queues[session_id].put(request) + self._put_into_session_queue(session_id, request) return runtime_pb2.ExpertResponse() async def _push_outputs( diff --git a/src/petals/server/server.py b/src/petals/server/server.py index d061d0a..72db9ce 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -528,23 +528,21 @@ class ModuleContainer(threading.Thread): self.dht, self.module_backends = dht, module_backends self.server_info, self.update_period, self.expiration = server_info, update_period, expiration - self.push_manager = mp.Manager() - self.push_manager.__enter__() - session_queues = self.push_manager.dict() + handler_event_queues = [mp.Queue() for _ in range(num_handlers)] self.conn_handlers = [ TransformerConnectionHandler( dht, self.module_backends, adapters=server_info.adapters, dht_prefix=dht_prefix, - push_manager=self.push_manager, - session_queues=session_queues, + handler_event_queues=handler_event_queues, + handler_index=i, inference_max_length=inference_max_length, request_timeout=request_timeout, session_timeout=session_timeout, step_timeout=step_timeout, ) - for _ in range(num_handlers) + for i in range(num_handlers) ] self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs) @@ -607,7 +605,6 @@ class ModuleContainer(threading.Thread): logger.debug("Shutting down connection handlers") for handler in self.conn_handlers: handler.shutdown() - self.push_manager.__exit__(None, None, None) logger.debug(f"Shutting down pools") for pool in self.runtime.pools: From 398a384075d17aad1ded769b876e659d8c15802a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 19 Jul 2023 13:08:52 +0300 Subject: [PATCH 130/168] Inherit bitsandbytes compute dtype correctly (override peft quirk) (#377) --- src/petals/client/routing/sequence_manager.py | 3 +-- src/petals/utils/peft.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 5b1ab3f..9230185 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -212,7 +212,6 @@ class RemoteSequenceManager: end_index: int, *, cache_tokens_needed: Optional[int], - overhead_coeff: float = 1.82, # Backend overhead (empirically measured) overhead_delay: float = 0.018, # Serialization overhead (empirically measured) default_inference_rps: float = 300, # If inference RPS unknown alloc_delay: float = 10, # If not enough cache left, we penalize the edge @@ -266,7 +265,7 @@ class RemoteSequenceManager: inference_rps = span.server_info.inference_rps if inference_rps is None: inference_rps = default_inference_rps - graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), overhead_coeff / inference_rps) + graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), 1.0 / inference_rps) return graph diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index bbad779..23661ae 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -198,6 +198,7 @@ def create_lora_adapter(block, quant_type: QuantType): child.out_features, **kwargs, ) + lora_wrapped_child.compute_dtype = child.compute_dtype else: bias = hasattr(child, "bias") and child.bias is not None lora_wrapped_child = LoraLinear( From 3218534745397dac42823a57bac0bb573e6cacf4 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 15:25:34 +0400 Subject: [PATCH 131/168] Fix --token arg (#378) --- src/petals/cli/run_server.py | 6 +++++- src/petals/server/from_pretrained.py | 6 +++--- src/petals/server/server.py | 4 ++-- src/petals/utils/peft.py | 6 +++--- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index c7264b4..8820dd2 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -25,6 +25,11 @@ def main(): help="path or name of a pretrained model, converted with cli/convert_model.py") group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path") + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()") + group.add_argument("--use_auth_token", action="store_true", dest="token", + help="Read token saved by `huggingface-cli login") + parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve") parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve") parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix") @@ -132,7 +137,6 @@ def main(): parser.add_argument("--mean_balance_check_period", type=float, default=60, help="Check the swarm's balance every N seconds (and rebalance it if necessary)") - parser.add_argument("--token", action='store_true', help="Hugging Face hub auth token for .from_pretrained()") parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType], help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or " "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. " diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 9898759..950746e 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -34,7 +34,7 @@ def load_pretrained_block( config: Optional[PretrainedConfig] = None, torch_dtype: Union[torch.dtype, str] = "auto", revision: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, ) -> nn.Module: @@ -82,7 +82,7 @@ def _load_state_dict_from_repo( block_prefix: str, *, revision: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, ) -> StateDict: @@ -125,7 +125,7 @@ def _load_state_dict_from_file( filename: str, *, revision: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, delay: float = 30, diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 72db9ce..ccc5292 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -77,7 +77,7 @@ class Server: balance_quality: float = 0.75, mean_balance_check_period: float = 120, mean_block_selection_delay: float = 2.5, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, quant_type: Optional[QuantType] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, @@ -409,7 +409,7 @@ class ModuleContainer(threading.Thread): update_period: float, expiration: Optional[float], revision: Optional[str], - token: Optional[str], + token: Optional[Union[str, bool]], quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], should_validate_reachability: bool, diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index 23661ae..de48cd2 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -1,7 +1,7 @@ import contextlib import re import time -from typing import Optional, Sequence +from typing import Optional, Sequence, Union import bitsandbytes as bnb import torch @@ -50,7 +50,7 @@ def get_adapter_from_repo( block_idx: Optional[int] = None, device: Optional[int] = None, *, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, **kwargs, ): config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs) @@ -72,7 +72,7 @@ def load_peft( device: Optional[int] = None, *, revision: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, delay: float = 30, From 057a2fb5de0f1bb6050213b7ac53d3e2293dad3a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 19:15:53 +0400 Subject: [PATCH 132/168] Support Llama 2 (#379) --- src/petals/__init__.py | 2 +- src/petals/cli/run_server.py | 23 ++++++++++++++--------- src/petals/data_structures.py | 4 +++- src/petals/models/bloom/config.py | 2 ++ src/petals/models/llama/block.py | 8 ++++++-- src/petals/models/llama/config.py | 14 +++++++++++--- src/petals/server/backend.py | 9 ++++++--- src/petals/server/from_pretrained.py | 4 ++++ src/petals/server/reachability.py | 3 +-- src/petals/server/server.py | 26 ++++++++++++++++++++------ src/petals/utils/auto_config.py | 15 +++++++++++---- src/petals/utils/convert_block.py | 9 +++++++-- src/petals/utils/hf_auth.py | 7 +++++++ src/petals/utils/misc.py | 9 --------- src/petals/utils/peft.py | 2 +- 15 files changed, 94 insertions(+), 43 deletions(-) create mode 100644 src/petals/utils/hf_auth.py diff --git a/src/petals/__init__.py b/src/petals/__init__.py index b72776c..53bdd51 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.2.0.dev3" +__version__ = "1.2.0.dev4" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 8820dd2..abd3faf 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -25,6 +25,8 @@ def main(): help="path or name of a pretrained model, converted with cli/convert_model.py") group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path") + parser.add_argument("--public_name", type=str, default=None, help="Public name to be reported in the leaderboard") + group = parser.add_mutually_exclusive_group(required=False) group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()") group.add_argument("--use_auth_token", action="store_true", dest="token", @@ -59,16 +61,22 @@ def main(): parser.add_argument('--num_handlers', type=int, default=8, required=False, help='server will use this many processes to handle incoming requests') - parser.add_argument('--min_batch_size', type=int, default=1, - help='Minimum required batch size for all operations (in total tokens)') - parser.add_argument('--max_batch_size', type=int, default=2048, - help='The total number of tokens in the same batch will not exceed this value') parser.add_argument('--prefetch_batches', type=int, default=1, required=False, help='Pre-form this many subsequent batches while GPU is processing the current one') parser.add_argument('--sender_threads', type=int, default=1, required=False, help='Use this many threads to pass results/exceptions from Runtime to Pools') - parser.add_argument('--inference_max_length', type=int, default=2048, - help='Maximum total sequence length permitted per inference, defaults to 16384 tokens') + + parser.add_argument('--inference_max_length', type=int, default=None, + help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. ' + 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + parser.add_argument('--min_batch_size', type=int, default=1, + help='Minimum required batch size for all operations (in total tokens)') + parser.add_argument('--max_batch_size', type=int, default=None, + help='The total number of tokens in the same batch will not exceed this value. ' + 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + parser.add_argument('--attn_cache_tokens', type=int, default=None, + help='The number of past attention key/value pairs that will be stored between inference steps. ' + 'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)') parser.add_argument('--cache_dir', type=str, default=None, help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.') @@ -86,9 +94,6 @@ def main(): parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", help="Use this dtype to store block weights and do computations. " "By default, respect the dtypes in the pre-trained state dict.") - parser.add_argument('--attn_cache_tokens', type=int, default=8192, - help='The number of past attention key/value pairs that will be stored between inference steps. ' - 'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).') parser.add_argument('--alloc_timeout', type=float, default=5, help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed ' 'before rejecting the request') diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index e3a3e03..38d706f 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -27,12 +27,14 @@ class ServerInfo: state: ServerState throughput: RPS + public_name: Optional[str] = None + version: Optional[str] = None + network_rps: Optional[RPS] = None forward_rps: Optional[RPS] = None inference_rps: Optional[RPS] = None adapters: Sequence[str] = () - version: Optional[str] = None torch_dtype: Optional[str] = None quant_type: Optional[str] = None using_relay: Optional[bool] = None diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py index 23521fc..494c187 100644 --- a/src/petals/models/bloom/config.py +++ b/src/petals/models/bloom/config.py @@ -18,6 +18,8 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM attn_class = BloomAttention block_prefix = "h" + num_key_value_groups = 1 + @classmethod def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py index 2f07188..55f659a 100644 --- a/src/petals/models/llama/block.py +++ b/src/petals/models/llama/block.py @@ -73,7 +73,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer): ) -> Tuple[torch.Tensor]: key_states, value_states = key_value key_states = key_states.permute(0, 2, 1) - key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim) + key_states = key_states.view( + batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) value_states = value_states.view(*key_states.shape) return (key_states, value_states) @@ -81,7 +83,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer): self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int ) -> Tuple[torch.Tensor]: key_states, value_states = key_value - value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim) + value_states = value_states.view( + batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) key_states = key_states.view(*value_states.shape) key_states = key_states.permute(0, 2, 1) return (key_states, value_states) diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py index b21fa9a..241525a 100644 --- a/src/petals/models/llama/config.py +++ b/src/petals/models/llama/config.py @@ -18,13 +18,17 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM attn_class = LlamaAttention block_prefix = "model.layers" + @property + def num_key_value_groups(self): + return self.num_attention_heads // self.num_key_value_heads + @classmethod def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): logger.info( - "LLaMA is available solely for non-commercial research purposes. " - "Make sure you follow the terms of use: https://bit.ly/llama-license" + "Make sure you follow the LLaMA's terms of use: " + "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1" ) loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) @@ -34,4 +38,8 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM if not dht_prefix.endswith("-hf"): dht_prefix += "-hf" logger.info(f"Using DHT prefix: {dht_prefix}") - return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + + result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + config = result[0] if isinstance(result, tuple) else result + config.pretraining_tp = 1 # This may give less accurate results but it doesn't matter if we use quantization + return result diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 4220546..d61470a 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -81,6 +81,7 @@ class TransformerBackend(ModuleBackend): head_dim = self.config.hidden_size // self.config.num_attention_heads cache_tensors = [] for device, num_heads in zip(self.module.devices, self.shard_num_heads): + num_heads //= self.config.num_key_value_groups keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device) values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device) cache_tensors.extend((keys, values)) @@ -123,8 +124,10 @@ class TransformerBackend(ModuleBackend): """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past""" key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2]) for i in range(len(key_cache)): - key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length] # [batch * num_heads, head_dim, kv_length] - value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length] # [batch * num_heads, kv_length, head_dim] + key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length] + # shape: [batch * num_kv_heads, head_dim, kv_length] + value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length] + # shape: [batch * num_kv_heads, kv_length, head_dim] layer_past = tuple(chain(*zip(key_cache, value_cache))) return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past @@ -132,7 +135,7 @@ class TransformerBackend(ModuleBackend): self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int ): """Writes new key/value tensors back into cache, works in-place""" - _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape + _batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]): new_key = new_key.view(*cache_key.shape[:3], new_length) cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length] diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 950746e..2a2560b 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -23,6 +23,7 @@ from petals.constants import DTYPE_MAP from petals.server.block_utils import resolve_block_dtype from petals.utils.auto_config import AutoDistributedConfig from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for +from petals.utils.hf_auth import always_needs_auth logger = get_logger(__name__) @@ -86,6 +87,9 @@ def _load_state_dict_from_repo( cache_dir: str, max_disk_space: Optional[int] = None, ) -> StateDict: + if always_needs_auth(model_name) and token is None: + token = True + index_file = get_file_from_repo( model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir ) diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py index 03e01fc..ff8dd14 100644 --- a/src/petals/server/reachability.py +++ b/src/petals/server/reachability.py @@ -145,8 +145,7 @@ class ReachabilityProtocol(ServicerBase): async with protocol.serve(common_p2p): await protocol._stop.wait() except Exception as e: - logger.warning(f"Reachability service failed: {repr(e)}") - logger.debug("See detailed traceback below:", exc_info=True) + logger.debug("Reachability service failed:", exc_info=True) if not ready.done(): ready.set_exception(e) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index ccc5292..947dbd8 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -50,18 +50,19 @@ class Server: initial_peers: List[str], dht_prefix: Optional[str], converted_model_name_or_path: str, + public_name: Optional[str] = None, throughput: Union[float, str], num_blocks: Optional[int] = None, block_indices: Optional[str] = None, num_handlers: int = 8, + inference_max_length: Optional[int] = None, min_batch_size: int = 1, - max_batch_size: int = 2048, - inference_max_length: int = 2048, + max_batch_size: Optional[int] = None, + attn_cache_tokens: Optional[int] = None, torch_dtype: str = "auto", revision: Optional[str] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, - attn_cache_tokens: int = 8192, alloc_timeout: float = 5, device: Optional[Union[str, torch.device]] = None, compression=CompressionType.NONE, @@ -93,8 +94,6 @@ class Server: self.converted_model_name_or_path = converted_model_name_or_path self.num_handlers = num_handlers - self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size - self.inference_max_length = inference_max_length self.compression = compression self.stats_report_interval, self.update_period = stats_report_interval, update_period self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads @@ -177,8 +176,19 @@ class Server: self.quant_type = quant_type logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format") + is_multiquery_attn = self.block_config.num_key_value_groups > 1 + if max_batch_size is None: + max_batch_size = 8192 if is_multiquery_attn else 2048 + if inference_max_length is None: + inference_max_length = 8192 if is_multiquery_attn else 2048 + self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size + self.inference_max_length = inference_max_length + # For attention cache in GPU or RAM + if attn_cache_tokens is None: + attn_cache_tokens = 32768 if is_multiquery_attn else 2048 cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens + cache_values_per_block //= self.block_config.num_key_value_groups self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8 # For disk cache @@ -222,8 +232,9 @@ class Server: throughput_info = {"throughput": throughput} self.server_info = ServerInfo( state=ServerState.JOINING, - adapters=tuple(adapters), + public_name=public_name, version=petals.__version__, + adapters=tuple(adapters), torch_dtype=str(torch_dtype).replace("torch.", ""), quant_type=quant_type.name.lower(), using_relay=self.dht.client_mode, @@ -642,7 +653,10 @@ class ModuleAnnouncerThread(threading.Thread): self.dht = dht self.server_info = server_info self.memory_cache = memory_cache + self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8 + self.bytes_per_token //= block_config.num_key_value_groups + self.update_period = update_period self.expiration = expiration self.trigger = threading.Event() diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py index f587051..13c7298 100644 --- a/src/petals/utils/auto_config.py +++ b/src/petals/utils/auto_config.py @@ -1,8 +1,12 @@ +import os +import re from dataclasses import dataclass -from typing import Optional, Type +from typing import Optional, Type, Union from transformers import AutoConfig, PretrainedConfig, PreTrainedModel +from petals.utils.hf_auth import always_needs_auth + @dataclass class _ModelClasses: @@ -26,8 +30,11 @@ class _AutoDistributedBase: _mapping_field = None # Should be defined in child classes @classmethod - def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig: - config = AutoConfig.from_pretrained(*args, **kwargs) + def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig: + if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs: + kwargs["token"] = True + + config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs) if config.model_type not in _CLASS_MAPPING: raise ValueError(f"Petals does not support model type {config.model_type}") @@ -35,7 +42,7 @@ class _AutoDistributedBase: if proper_cls is None: raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}") - return proper_cls.from_pretrained(*args, **kwargs) + return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs) class AutoDistributedConfig(_AutoDistributedBase): diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index f8a4637..94d3e29 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -2,6 +2,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor parallelism """ import re +from enum import Enum from typing import Optional, Sequence import tensor_parallel as tp @@ -11,12 +12,16 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler from tensor_parallel.slicing_configs import get_bloom_config from transformers import PretrainedConfig -from petals.utils.misc import QuantType - use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) +class QuantType(Enum): + NONE = 0 + INT8 = 1 # 8-bit as in the LLM.int8() paper + NF4 = 2 # 4-bit as in the QLoRA paper + + def convert_block( block: nn.Module, block_index: int, diff --git a/src/petals/utils/hf_auth.py b/src/petals/utils/hf_auth.py new file mode 100644 index 0000000..6446b89 --- /dev/null +++ b/src/petals/utils/hf_auth.py @@ -0,0 +1,7 @@ +import os +from typing import Union + + +def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool: + loading_from_repo = model_name is not None and not os.path.isdir(model_name) + return loading_from_repo and model_name.startswith("meta-llama/Llama-2-") diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py index 99b246c..2f67202 100644 --- a/src/petals/utils/misc.py +++ b/src/petals/utils/misc.py @@ -1,14 +1,5 @@ -from enum import Enum - import torch - -class QuantType(Enum): - NONE = 0 - INT8 = 1 # 8-bit as in the LLM.int8() paper - NF4 = 2 # 4-bit as in the QLoRA paper - - DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index de48cd2..da25623 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -17,8 +17,8 @@ from safetensors.torch import load_file from transformers.utils import get_file_from_repo from petals.server.block_utils import resolve_block_dtype +from petals.utils.convert_block import QuantType from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for -from petals.utils.misc import QuantType logger = get_logger(__name__) From e9a20e7e5300508c0a79edc4b75e4f70caaff1b8 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 20:28:23 +0400 Subject: [PATCH 133/168] Require accelerate>=0.20.3 as transformers do (#383) --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 7341684..6abd0ee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ python_requires = >=3.8,<3.11 install_requires = torch>=1.12 bitsandbytes==0.40.1.post1 - accelerate>=0.16.0,<0.21.0 + accelerate>=0.20.3,<0.21.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 transformers>=4.31.0,<5.0.0 From b1ff8bdd6c370c16bd68dcdb73d35161daa7c6f9 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 21:13:24 +0400 Subject: [PATCH 134/168] Bump version to 2.0.0.post1 (#384) --- README.md | 19 +++++++++++-------- src/petals/__init__.py | 2 +- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 35d5cba..583df81 100644 --- a/README.md +++ b/README.md @@ -5,13 +5,16 @@

-Generate text with distributed [LLaMA-65B](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md), [Guanaco](https://huggingface.co/timdettmers/guanaco-65b), [BLOOM-176B](https://huggingface.co/bigscience/bloom), or [BLOOMZ](https://huggingface.co/bigscience/bloomz) and fine-tune them for your own tasks — right from your desktop computer or Google Colab: +Generate text with distributed [LLaMA 2](https://ai.meta.com/llama/) ([70B](https://huggingface.co/meta-llama/Llama-2-70b-hf), [70B-Chat](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: ```python from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM -model_name = "enoch/llama-65b-hf" # You can also use "bigscience/bloom" or "bigscience/bloomz" +model_name = "enoch/llama-65b-hf" +# You can also use "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-70b-chat-hf", +# "bigscience/bloom", or "bigscience/bloomz" + tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoDistributedModelForCausalLM.from_pretrained(model_name) # Embeddings & prompts are on your device, transformer blocks are distributed across the Internet @@ -25,7 +28,7 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... 🚀  Try now in Colab

-📋 Make sure you follow the model's terms of use (see [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license) licenses). Note that LLaMA is available for non-commercial purposes only, and you have to file a request [here](https://bit.ly/llama-license) to use it in your own projects. +📋 Make sure you follow the model's terms of use (see [LLaMA 2](https://bit.ly/llama2-license), [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license) licenses). 🔏 Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. @@ -35,7 +38,7 @@ Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linu ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -pip install git+https://github.com/bigscience-workshop/petals +pip install --upgrade petals python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b ``` @@ -46,7 +49,7 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach python -m petals.cli.run_server --port 31330 enoch/llama-65b-hf --adapters timdettmers/guanaco-65b ``` -This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. +This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. 🔒 Hosting a server does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). @@ -74,8 +77,8 @@ Learning more: ## How does it work? -- Petals runs large language models like [LLaMA-65B](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. -- Single-batch inference runs at 3-4 steps/sec for LLaMA-65B and ≈ 1 step/sec for BLOOM-176B — [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](http://chat.petals.ml) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. +- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. +- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](http://chat.petals.ml) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.

@@ -94,7 +97,7 @@ Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/d ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -pip install git+https://github.com/bigscience-workshop/petals +pip install --upgrade petals ``` If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes). diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 53bdd51..8f0a0ec 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.2.0.dev4" +__version__ = "2.0.0.post1" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): From ddcda02b061607a55831556298e06e97956ef418 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Thu, 20 Jul 2023 08:51:17 +0000 Subject: [PATCH 135/168] Hardcode IPs until DNS issues get resolved --- src/petals/__init__.py | 2 +- src/petals/constants.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 8f0a0ec..1dfbfbe 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.0.0.post1" +__version__ = "2.0.0.post2" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/constants.py b/src/petals/constants.py index b04ad03..dc6e349 100644 --- a/src/petals/constants.py +++ b/src/petals/constants.py @@ -1,6 +1,10 @@ import torch PUBLIC_INITIAL_PEERS = [ + # Temporary IPs until DNS issues get resolved + "/ip4/159.223.29.252/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", + "/ip4/24.144.96.147/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", + # Default DNS addresses "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", From e51e84631d9f15b20fd32c6ec6bb1e2e209548a1 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 20 Jul 2023 19:59:28 +0300 Subject: [PATCH 136/168] Update to petals.dev (#390) Since `petals.ml` DNS record is still unavailable, we're switching everything to https://petals.dev Co-authored-by: Aleksandr Borzunov --- README.md | 8 ++++---- examples/prompt-tuning-sst2.ipynb | 2 +- src/petals/cli/run_server.py | 2 +- src/petals/client/routing/sequence_manager.py | 2 +- src/petals/constants.py | 19 ++++++++++--------- src/petals/server/reachability.py | 4 ++-- 6 files changed, 19 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 583df81..784d0f1 100644 --- a/README.md +++ b/README.md @@ -65,8 +65,8 @@ Basic tutorials: Useful tools and advanced guides: -- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/borzunov/chat.petals.ml) -- [Monitor](http://health.petals.ml) for the public swarm: [source code](https://github.com/borzunov/health.petals.ml) +- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/borzunov/chat.petals.dev) +- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/borzunov/health.petals.dev) - Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) @@ -78,7 +78,7 @@ Learning more: ## How does it work? - Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. -- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](http://chat.petals.ml) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. +- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.

@@ -218,5 +218,5 @@ _arXiv preprint arXiv:2209.01188,_ 2022. This project is a part of the BigScience research workshop.

- +

diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 876db8f..9123c1a 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -330,7 +330,7 @@ "id": "51770911" }, "source": [ - "Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](http://health.petals.ml/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!" + "Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](https://health.petals.dev/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!" ] }, { diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index abd3faf..46b1163 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -152,7 +152,7 @@ def main(): "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism") parser.add_argument("--skip_reachability_check", action='store_true', - help="Skip checking this server's reachability via health.petals.ml " + help="Skip checking this server's reachability via health.petals.dev " "when connecting to the public swarm. If you connect to a private swarm, " "the check is skipped by default. Use this option only if you know what you are doing") diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 9230185..f0b0ce0 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -513,7 +513,7 @@ class MissingBlocksError(RuntimeError): def __init__(self, block_indices: Union[int, Sequence[int]]): super().__init__( f"No servers holding blocks {block_indices} are online. " - f"You can check the public swarm's state at http://health.petals.ml " + f"You can check the public swarm's state at https://health.petals.dev " f"If there are not enough servers, please connect your GPU: " f"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity " ) diff --git a/src/petals/constants.py b/src/petals/constants.py index dc6e349..d307b81 100644 --- a/src/petals/constants.py +++ b/src/petals/constants.py @@ -1,17 +1,18 @@ import torch PUBLIC_INITIAL_PEERS = [ - # Temporary IPs until DNS issues get resolved - "/ip4/159.223.29.252/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", - "/ip4/24.144.96.147/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", - # Default DNS addresses - "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", - "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", - "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", - "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", + # IPv4 DNS addresses + "/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", + "/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", + # IPv6 DNS addresses + "/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", + "/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", + # Reserved IPs + "/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", + "/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", ] # The reachability API is currently used only when connecting to the public swarm -REACHABILITY_API_URL = "http://health.petals.ml" +REACHABILITY_API_URL = "https://health.petals.dev" DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py index ff8dd14..497ddeb 100644 --- a/src/petals/server/reachability.py +++ b/src/petals/server/reachability.py @@ -28,7 +28,7 @@ def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float response = r.json() if response["success"]: - logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon") + logger.info("Server is reachable from the Internet. It will appear at https://health.petals.dev soon") return if attempt_no == 0: @@ -37,7 +37,7 @@ def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes") time.sleep(retry_delay) except Exception as e: - logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}") + logger.warning(f"Skipping reachability check because health.petals.dev is down: {repr(e)}") return raise RuntimeError( From d49d9ad0cf2b1d332667dab070efc3aecd8a5c4a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 20 Jul 2023 21:07:00 +0400 Subject: [PATCH 137/168] Bump version to 2.0.0.post3 (#391) --- src/petals/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 1dfbfbe..c696bc6 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.0.0.post2" +__version__ = "2.0.0.post3" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): From b6b3ae964ff425ad94c2fc62572451d1e673c64f Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 20 Jul 2023 23:20:15 +0400 Subject: [PATCH 138/168] Fix --attn_cache_tokens default (#392) --- src/petals/server/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 947dbd8..5cdca46 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -186,7 +186,7 @@ class Server: # For attention cache in GPU or RAM if attn_cache_tokens is None: - attn_cache_tokens = 32768 if is_multiquery_attn else 2048 + attn_cache_tokens = 32768 if is_multiquery_attn else 8192 cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens cache_values_per_block //= self.block_config.num_key_value_groups self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8 From 6e4ebb94d2b84d8b278b328331c126811f1e0916 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 21 Jul 2023 11:09:24 +0400 Subject: [PATCH 139/168] Fix deadlocks in MemoryCache (#396) - Fix deadlocks in MemoryCache - Set default --alloc_timeout to 1 until the MemoryCache update --- src/petals/cli/run_server.py | 2 +- src/petals/server/memory_cache.py | 45 +++++++++++++------------------ 2 files changed, 20 insertions(+), 27 deletions(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 46b1163..a33e233 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -94,7 +94,7 @@ def main(): parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", help="Use this dtype to store block weights and do computations. " "By default, respect the dtypes in the pre-trained state dict.") - parser.add_argument('--alloc_timeout', type=float, default=5, + parser.add_argument('--alloc_timeout', type=float, default=1, help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed ' 'before rejecting the request') parser.add_argument('--revision', type=str, default=None, diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py index a1e2f26..c2aa192 100644 --- a/src/petals/server/memory_cache.py +++ b/src/petals/server/memory_cache.py @@ -90,7 +90,7 @@ class MemoryCache: logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)") yield handles finally: - await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task)) + self._free(max_alloc_size, alloc_task) @staticmethod def get_allocation_size(*descriptors: TensorDescriptor) -> int: @@ -111,25 +111,19 @@ class MemoryCache: async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory): if self.current_size_bytes + alloc_size > self.max_size_bytes: await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout) - async with hivemind.utils.enter_asynchronously(self._lock_metadata): + with self._lock_metadata: handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors))) self.current_size_bytes += alloc_size self.handle_counter += len(handles) # note: this will eventually overflow and it is okay self._pipe_send.send((handles, descriptors)) return handles - async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task): - """ - This method should be called inside asyncio.shield() because: - - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation - - _schedule_free() must finish freeing memory even in case of cancellation - """ - + def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None: if alloc_task.exception() is not None: return handles = alloc_task.result() - async with hivemind.utils.enter_asynchronously(self._lock_metadata): + with self._lock_metadata: self._pipe_send.send((handles, None)) # signal runtime to free these handles self.current_size_bytes -= alloc_size self._memory_freed_event.set() @@ -160,22 +154,21 @@ class MemoryCache: assert os.getpid() == self.runtime_pid # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here - with self._lock_metadata: - # read creation/deletion requests from connection handlers - while self._pipe_recv.poll(): - recv_handles, recv_data = self._pipe_recv.recv() - if recv_data is not None: # create new tensors - assert len(recv_handles) == len(recv_data) - for handle, descr in zip(recv_handles, recv_data): - self._allocated_tensors[handle] = descr.make_zeros() - assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})" - else: # delete tensors by handle - for handle in recv_handles: - if handle not in self._allocated_tensors: - logger.warning( - f"Sanity check failed: asked to delete handle {handle}, but there is no such handle" - ) - self._allocated_tensors.pop(handle, None) + # read creation/deletion requests from connection handlers + while self._pipe_recv.poll(): + recv_handles, recv_data = self._pipe_recv.recv() + if recv_data is not None: # create new tensors + assert len(recv_handles) == len(recv_data) + for handle, descr in zip(recv_handles, recv_data): + self._allocated_tensors[handle] = descr.make_zeros() + assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})" + else: # delete tensors by handle + for handle in recv_handles: + if handle not in self._allocated_tensors: + logger.warning( + f"Sanity check failed: asked to delete handle {handle}, but there is no such handle" + ) + self._allocated_tensors.pop(handle, None) yield tuple(self._allocated_tensors[handle] for handle in handles) From eb0664b993d6cc0c84e3c3c79b6c3ae152f525fc Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 22 Jul 2023 13:07:43 +0400 Subject: [PATCH 140/168] Support Python 3.11 (#393) --- .github/workflows/run-tests.yaml | 2 +- setup.cfg | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index a81592b..7ec5bf3 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.8', '3.9', '3.10' ] + python-version: [ '3.8', '3.9', '3.10', '3.11' ] fail-fast: false timeout-minutes: 15 steps: diff --git a/setup.cfg b/setup.cfg index 6abd0ee..417a126 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,7 @@ classifiers = package_dir = = src packages = find: -python_requires = >=3.8,<3.11 +python_requires = >=3.8 install_requires = torch>=1.12 bitsandbytes==0.40.1.post1 @@ -39,7 +39,7 @@ install_requires = transformers>=4.31.0,<5.0.0 speedtest-cli==2.1.3 pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind==1.1.8 - hivemind==1.1.8 + hivemind @ git+https://github.com/learning-at-home/hivemind tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 From 8666653cf562519cf38e50ccd6712c3f8ae7908e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 22 Jul 2023 18:27:58 +0400 Subject: [PATCH 141/168] Fix routing through relay, default network RPS, --token, logging, readme (#399) * Hide GeneratorExit in _iterate_inference_steps() * Update README.md about `--public_name` * Use .from_pretrained(..., use_auth_token=token) instead of token=token until it's fully supported across HF libs * Use default network speed 25 Mbit/s * Apply relay penalty in max-throughput routing * Replace RPS with "tokens/sec per block" in logs * Increase default expiration --- README.md | 8 ++- src/petals/client/routing/sequence_manager.py | 12 +++- src/petals/server/from_pretrained.py | 2 +- src/petals/server/handler.py | 2 +- src/petals/server/server.py | 4 +- src/petals/server/throughput.py | 65 +++++++++---------- src/petals/utils/auto_config.py | 8 ++- tests/scripts/remove_old_models.py | 25 ------- 8 files changed, 58 insertions(+), 68 deletions(-) delete mode 100644 tests/scripts/remove_old_models.py diff --git a/README.md b/README.md index 784d0f1..e4bca6e 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,13 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... ### Connect your GPU and increase Petals capacity +Petals is a community-run system — we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serving one of the models! + Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+): ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -pip install --upgrade petals +pip install git+https://github.com/bigscience-workshop/petals python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b ``` @@ -55,6 +57,8 @@ This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.c 💬 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! +🏆 If you host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks! You can specify them with `--public_name YOUR_NAME`. We will show them once your server loads all blocks. + ### Check out tutorials, examples, and more Basic tutorials: @@ -97,7 +101,7 @@ Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/d ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia -pip install --upgrade petals +pip install git+https://github.com/bigscience-workshop/petals ``` If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes). diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index f0b0ce0..c980412 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -291,7 +291,9 @@ class RemoteSequenceManager: # This is okay since false positives are more costly than false negatives here. return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left - def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]: + def _make_sequence_with_max_throughput( + self, start_index: int, end_index: int, *, relay_penalty: float = 0.5 + ) -> List[RemoteSpanInfo]: span_sequence = [] current_index = start_index while current_index < end_index: @@ -299,7 +301,13 @@ class RemoteSequenceManager: if not candidate_spans: raise MissingBlocksError(current_index) - span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) + span_weights = np.array( + [ + span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty) + for span in candidate_spans + ], + dtype=np.float64, + ) chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 2a2560b..bfbf03e 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -40,7 +40,7 @@ def load_pretrained_block( max_disk_space: Optional[int] = None, ) -> nn.Module: if config is None: - config = AutoDistributedConfig.from_pretrained(model_name, token=token) + config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token) if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 5d0a3d4..d3776de 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -347,7 +347,7 @@ class TransformerConnectionHandler(ConnectionHandler): anext_task.cancel() get_push_task.cancel() return - except: + except Exception: logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True) raise diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 5cdca46..6d5c293 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -104,7 +104,7 @@ class Server: self.block_config = AutoDistributedConfig.from_pretrained( converted_model_name_or_path, - token=token, + use_auth_token=token, revision=revision, ) @@ -117,7 +117,7 @@ class Server: self.dht_prefix = dht_prefix if expiration is None: - expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) + expiration = max(3 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) self.expiration = expiration self.request_timeout = request_timeout diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index d92355e..9e2ad6f 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -96,7 +96,7 @@ def get_server_throughput( throughput = throughput_info["forward_rps"] / average_blocks_used throughput = min(throughput, throughput_info.get("network_rps", math.inf)) throughput_info["throughput"] = throughput - logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks") + logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks") return throughput_info @@ -109,13 +109,10 @@ def measure_throughput_info( quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], ) -> Dict[str, float]: - """Measure network and compute throughput in forward pass tokens per second""" - logger.info( "Measuring network and compute throughput. This takes about a minute and will be cached for future runs" ) - - throughput_info = { + return { "inference_rps": measure_compute_rps( config, device, @@ -136,37 +133,39 @@ def measure_throughput_info( n_steps=10, inference=False, ), + "network_rps": measure_network_rps(config), } - try: - throughput_info["network_rps"] = measure_network_rps(config) - except Exception as e: - logger.info(f"Network throughput is not available: {e}") - return throughput_info - -def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]: - pipe_recv, pipe_send = mp.Pipe(duplex=False) - process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,)) - process.start() - - if not pipe_recv.poll(timeout): - process.terminate() - raise RuntimeError(f"speedtest did not finish in {timeout} seconds") - network_info = pipe_recv.recv() - if "exception" in network_info: - raise RuntimeError(f"speedtest failed: {network_info['exception']}") +def measure_network_rps( + config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 25e6 +) -> Optional[float]: bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward - network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request - if network_rps == 0: - raise RuntimeError("speedtest has returned network_rps == 0") - - logger.info( - f"Network throughput: {network_rps:.1f} RPS " - f"({network_info['download'] / 1e6:.2f} Mbit/s on download, " - f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)" - ) - return network_rps + try: + pipe_recv, pipe_send = mp.Pipe(duplex=False) + process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,)) + process.start() + + if not pipe_recv.poll(timeout): + process.terminate() + raise RuntimeError(f"speedtest did not finish in {timeout} seconds") + network_info = pipe_recv.recv() + if "exception" in network_info: + raise RuntimeError(f"speedtest failed: {network_info['exception']}") + + network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request + if network_rps == 0: + raise RuntimeError("speedtest has returned network_rps == 0") + + logger.info( + f"Network throughput: {network_rps:.1f} tokens/sec " + f"({network_info['download'] / 1e6:.2f} Mbit/s on download, " + f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)" + ) + return network_rps + except RuntimeError as e: + logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s") + return default_speed / bits_per_request def _measure_bits_per_second(pipe_send: mp.Pipe): @@ -215,7 +214,7 @@ def measure_compute_rps( devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common()) logger.info( - f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block " + f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block " f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})" ) return device_rps diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py index 13c7298..70f37a3 100644 --- a/src/petals/utils/auto_config.py +++ b/src/petals/utils/auto_config.py @@ -31,8 +31,12 @@ class _AutoDistributedBase: @classmethod def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig: - if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs: - kwargs["token"] = True + if ( + always_needs_auth(model_name_or_path) + and kwargs.get("token") is None + and kwargs.get("use_auth_token") is None + ): + kwargs["use_auth_token"] = True config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs) if config.model_type not in _CLASS_MAPPING: diff --git a/tests/scripts/remove_old_models.py b/tests/scripts/remove_old_models.py deleted file mode 100644 index 598fb3b..0000000 --- a/tests/scripts/remove_old_models.py +++ /dev/null @@ -1,25 +0,0 @@ -import argparse -from datetime import datetime - -from huggingface_hub import delete_repo, list_models - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Remove old testing models from HF hub") - parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained") - parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60) - parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained") - parser.add_argument("--dry_run", action="store_true") - - args = parser.parse_args() - - for model in list_models(author=args.author, full=True): - last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ") - - if model.modelId.endswith("-main") or "/test-" not in model.modelId: - continue # remove only test models - - if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated: - if args.dry_run: - print(f"{model.modelId} can be deleted") - else: - delete_repo(repo_id=model.modelId, token=args.use_auth_token) From 30b94ef18b61e19e3582a2edd0ce32d5cadabb3d Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 22 Jul 2023 18:49:37 +0400 Subject: [PATCH 142/168] If speedtest fails, assume network speed of 100 Mbit/s (#404) The value is chosen as some safe value below average at https://health.petals.dev/ Note that if a server uses relays, the effective throughput will be further divided by 2 (see #399). --- src/petals/server/throughput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 9e2ad6f..fbea3d2 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -138,7 +138,7 @@ def measure_throughput_info( def measure_network_rps( - config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 25e6 + config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6 # 100 Mbit/s ) -> Optional[float]: bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward try: From 5af04524dd8bc5ce6a9a11a35ee71714f310aad6 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 22 Jul 2023 22:10:46 +0300 Subject: [PATCH 143/168] Split long sequences into chunks (#403) This PR is designed to avoid OOMs when processing long sequences that happen due to the huge attention logits matrices. Co-authored-by: Alexander Borzunov --- .github/workflows/run-tests.yaml | 3 ++- src/petals/cli/run_server.py | 2 ++ src/petals/server/backend.py | 41 +++++++++++++++++++++++++++++--- src/petals/server/server.py | 5 ++++ tests/test_full_model.py | 12 +++++++--- 5 files changed, 56 insertions(+), 7 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 7ec5bf3..3bccda3 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -37,7 +37,8 @@ jobs: python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \ - --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log & + --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \ + --adapters $ADAPTER_NAME &> server1.log & SERVER1_PID=$! sleep 5 # wait for the first server to initialize DHT diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index a33e233..8132a39 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -74,6 +74,8 @@ def main(): parser.add_argument('--max_batch_size', type=int, default=None, help='The total number of tokens in the same batch will not exceed this value. ' 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024, + help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks') parser.add_argument('--attn_cache_tokens', type=int, default=None, help='The number of past attention key/value pairs that will be stored between inference steps. ' 'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)') diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index d61470a..8b788b0 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -27,7 +27,13 @@ class TransformerBackend(ModuleBackend): _peft_module = None def __init__( - self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs + self, + *args, + config: PretrainedConfig, + memory_cache: MemoryCache, + backend_dtype: torch.dtype, + max_chunk_size_bytes: int, + **kwargs, ): import petals.utils.peft as _peft_module @@ -37,6 +43,8 @@ class TransformerBackend(ModuleBackend): assert isinstance(self.module, TensorParallel) self.config = config self.memory_cache = memory_cache + self.max_chunk_size_bytes = max_chunk_size_bytes + for name, param in self.module.named_parameters(): assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does" for name, buf in self.module.named_buffers(): @@ -55,6 +63,7 @@ class TransformerBackend(ModuleBackend): ) self.dtype = backend_dtype + self.dtype_bytes = torch.finfo(self.dtype).bits // 8 self.shard_num_heads = [] for shard in self.module.module_shards: for submodule in shard.modules(): @@ -105,14 +114,40 @@ class TransformerBackend(ModuleBackend): inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" + seq_len = hidden_states.shape[1] + with self.memory_cache.use_cache( *inference_info.cache_handles ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter): self._reorder_cache_inplace(cache_tensors, hypo_ids) + + # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory` + # reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes` + # is at least 4-6x less than `autograd_memory`. + max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info) + output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) - hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True) + for offset in range(0, seq_len, max_chunk_length): + hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :] + output_hidden_states_chunk, new_kvs = self.module.forward( + hidden_states_chunk, layer_past=layer_past, use_cache=True + ) + if seq_len > max_chunk_length: + output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk + else: + output_hidden_states = output_hidden_states_chunk # saves one memcopy + layer_past = new_kvs + self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length) - return (hidden_states,) + return (output_hidden_states,) + + def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int: + # We assume that attention logit matrices are the main thing that consumes memory, given that + # the model uses multi-query attention + batch_size, seq_length, hidden_size = hidden_states.shape + worst_case_length = inference_info.prefix_length + seq_length + attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length + return max(1, self.max_chunk_size_bytes // attn_bytes_per_token) def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor): """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids""" diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 6d5c293..5cb9b91 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -58,6 +58,7 @@ class Server: inference_max_length: Optional[int] = None, min_batch_size: int = 1, max_batch_size: Optional[int] = None, + max_chunk_size_bytes: int = 256 * 1024 * 1024, attn_cache_tokens: Optional[int] = None, torch_dtype: str = "auto", revision: Optional[str] = None, @@ -183,6 +184,7 @@ class Server: inference_max_length = 8192 if is_multiquery_attn else 2048 self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size self.inference_max_length = inference_max_length + self.max_chunk_size_bytes = max_chunk_size_bytes # For attention cache in GPU or RAM if attn_cache_tokens is None: @@ -312,6 +314,7 @@ class Server: num_handlers=self.num_handlers, min_batch_size=self.min_batch_size, max_batch_size=self.max_batch_size, + max_chunk_size_bytes=self.max_chunk_size_bytes, inference_max_length=self.inference_max_length, torch_dtype=self.torch_dtype, cache_dir=self.cache_dir, @@ -412,6 +415,7 @@ class ModuleContainer(threading.Thread): block_indices: List[int], min_batch_size: int, max_batch_size: int, + max_chunk_size_bytes: int, torch_dtype: torch.dtype, cache_dir: str, max_disk_space: int, @@ -477,6 +481,7 @@ class ModuleContainer(threading.Thread): config=block_config, memory_cache=memory_cache, backend_dtype=torch_dtype, + max_chunk_size_bytes=max_chunk_size_bytes, args_schema=( BatchTensorDescriptor( 1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression diff --git a/tests/test_full_model.py b/tests/test_full_model.py index acd5e6a..511604b 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -28,7 +28,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f assert isinstance(model, DistributedBloomForCausalLM) assert len(model.transformer.h) == model.config.num_hidden_layers - test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] + test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"] with torch.inference_mode(): parallel_outputs = model.forward(test_inputs).logits @@ -43,8 +43,14 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) for t in range(embs.shape[1]): - recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) - if t == int(embs.shape[1] // 2) and pass_empty_tensors: + if t == 4: + recurrent_outputs.append(sess.step(embs[:, 4:9, :])) + elif 4 < t < 9: + continue + else: + recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) + + if t == 2 and pass_empty_tensors: recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) From c153cba1fa4dc02a8770b7dd85f700cd27f4ec37 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 23 Jul 2023 00:35:19 +0400 Subject: [PATCH 144/168] Add Llama 2, WSL instructions to readme (#406) --- README.md | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index e4bca6e..1fc4475 100644 --- a/README.md +++ b/README.md @@ -28,15 +28,17 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... 🚀  Try now in Colab

-📋 Make sure you follow the model's terms of use (see [LLaMA 2](https://bit.ly/llama2-license), [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license) licenses). +🦙 **Want to run LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev). -🔏 Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. +📋 **Terms of use.** Make sure you follow the model license (see the ones for [LLaMA 2](https://bit.ly/llama2-license), [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license)). + +🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. ### Connect your GPU and increase Petals capacity Petals is a community-run system — we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serving one of the models! -Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+): +🐍 **Linux + Anaconda.** Run these commands: ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia @@ -44,20 +46,24 @@ pip install git+https://github.com/bigscience-workshop/petals python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b ``` -Or run our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windows with [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)): +🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows). + +🐋 **Any OS + Docker.** Run our [Docker](https://www.docker.com) image: ```bash sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \ python -m petals.cli.run_server --port 31330 enoch/llama-65b-hf --adapters timdettmers/guanaco-65b ``` -This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. +These commands host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. + +🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add the `--token YOUR_TOKEN` argument to the commands above. -🔒 Hosting a server does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). +💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! -💬 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! +🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). -🏆 If you host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks! You can specify them with `--public_name YOUR_NAME`. We will show them once your server loads all blocks. +🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`. ### Check out tutorials, examples, and more From 48c6b6d9637f8f84a02dd85e43e7bd3790a16b1b Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 23 Jul 2023 00:41:41 +0400 Subject: [PATCH 145/168] Update README.md (#407) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1fc4475..685b93d 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach These commands host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. -🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add the `--token YOUR_TOKEN` argument to the commands above. +🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add the `--token YOUR_TOKEN` argument to the `petals.cli.run_server` command above. 💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! From ffb20b585c984c124b7292c73361b78fdae741c6 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 23 Jul 2023 13:08:07 +0400 Subject: [PATCH 146/168] Update commands for hosting Llama 2 in readme (#409) --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 685b93d..58a41c2 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,11 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach These commands host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. -🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add the `--token YOUR_TOKEN` argument to the `petals.cli.run_server` command above. +🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then use this command for `petals.cli.run_server`: + +```bash +python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE +``` 💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! From fd19c21859764a9f2f8c7ae5f5846359421f8105 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 23 Jul 2023 17:22:04 +0400 Subject: [PATCH 147/168] Update --update_period and --expiration defaults (#410) --- src/petals/cli/run_server.py | 2 +- src/petals/server/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 8132a39..d4fc999 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -110,7 +110,7 @@ def main(): 'If set to "auto" (default), the script evaluates network and compute throughput ' 'on the first run and uses these estimates for future runs. ' 'If set to "eval", the script re-evaluates the throughput and overrides the cache.') - parser.add_argument('--update_period', type=float, required=False, default=60, + parser.add_argument('--update_period', type=float, required=False, default=120, help='Server will report blocks to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, help='DHT entries will expire after this many seconds') diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 5cb9b91..fec6e82 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -118,7 +118,7 @@ class Server: self.dht_prefix = dht_prefix if expiration is None: - expiration = max(3 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) + expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) self.expiration = expiration self.request_timeout = request_timeout From f3fafd14a4c00448c67d028c6065e2bef2520fa5 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 23 Jul 2023 18:45:19 +0400 Subject: [PATCH 148/168] Bump version to 2.0.1 (#411) --- setup.cfg | 4 ++-- src/petals/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 417a126..7c04686 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,8 +38,8 @@ install_requires = tokenizers>=0.13.3 transformers>=4.31.0,<5.0.0 speedtest-cli==2.1.3 - pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind==1.1.8 - hivemind @ git+https://github.com/learning-at-home/hivemind + pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet + hivemind==1.1.9 tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/__init__.py b/src/petals/__init__.py index c696bc6..d2f91bc 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.0.0.post3" +__version__ = "2.0.1" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): From 8072cd9d1b23fb7b3868f7b4ce7cd7b017d251ed Mon Sep 17 00:00:00 2001 From: Guocheng Date: Tue, 25 Jul 2023 22:21:15 +0800 Subject: [PATCH 149/168] Fix stale link (#418) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 58a41c2..ecfb615 100644 --- a/README.md +++ b/README.md @@ -79,8 +79,8 @@ Basic tutorials: Useful tools and advanced guides: -- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/borzunov/chat.petals.dev) -- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/borzunov/health.petals.dev) +- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev) +- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev) - Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) From cdc0f7065338debc5ca3c304a2e43d57534f0b40 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 30 Jul 2023 16:07:38 +0200 Subject: [PATCH 150/168] Add Discord badge and more Discord links to readme (#422) --- README.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ecfb615..0c44ff6 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,11 @@


Run large language models at home, BitTorrent-style.
- Fine-tuning and inference up to 10x faster than offloading

-
+ Fine-tuning and inference up to 10x faster than offloading +

+ + +

Generate text with distributed [LLaMA 2](https://ai.meta.com/llama/) ([70B](https://huggingface.co/meta-llama/Llama-2-70b-hf), [70B-Chat](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: @@ -34,6 +37,8 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... 🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. +💬 **Any questions?** Ping us in [our Discord](https://discord.gg/J29mCBNBvm)! + ### Connect your GPU and increase Petals capacity Petals is a community-run system — we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serving one of the models! @@ -57,14 +62,14 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach These commands host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. +💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! + 🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then use this command for `petals.cli.run_server`: ```bash python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE ``` -💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! - 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). 🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`. From 44fefa5e54e457d915d7c5d84a8506a43e5a0e53 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 30 Jul 2023 21:31:19 +0200 Subject: [PATCH 151/168] Add connect_timeout (#423) --- src/petals/client/inference_session.py | 2 +- src/petals/client/remote_forward_backward.py | 29 ++++++++++--------- src/petals/client/routing/sequence_manager.py | 1 + src/petals/client/sequential_autograd.py | 4 +-- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 5e14d8a..d14c4a2 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -75,7 +75,7 @@ class _ServerInferenceSession: inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), - config.request_timeout, + config.connect_timeout, ) return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index df97db1..a116822 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -13,52 +13,53 @@ from hivemind.proto import runtime_pb2 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter from hivemind.utils.streaming import split_for_streaming +from petals.client.routing.sequence_manager import SequenceManagerConfig from petals.data_structures import ModuleUID, RPCInfo async def _forward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs ) -> List[torch.Tensor]: outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), - timeout=timeout, + timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in outputs.tensors] async def _backward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs ) -> List[torch.Tensor]: grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), - timeout=timeout, + timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] async def _forward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs ) -> List[torch.Tensor]: parts = ( runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) - outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout) - outputs = aiter_with_timeout(outputs, timeout) + outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout) + outputs = aiter_with_timeout(outputs, config.request_timeout) return await deserialize_tensor_stream(msg.tensors async for msg in outputs) async def _backward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs ) -> List[torch.Tensor]: parts = ( runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) - grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout) - grad_inputs = aiter_with_timeout(grad_inputs, timeout) + grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout) + grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout) return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs) @@ -67,7 +68,7 @@ async def run_remote_forward( stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, - timeout: float, + config: SequenceManagerConfig, metadata: Optional[bytes] = None, **kwargs, ) -> Tuple[torch.Tensor, ...]: @@ -110,7 +111,7 @@ async def run_remote_forward( size = sum(t.element_size() * t.nelement() for t in inputs) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) + deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) @@ -121,7 +122,7 @@ async def run_remote_backward( inputs: torch.Tensor, grad_outputs: List[torch.Tensor], *extra_tensors: torch.Tensor, - timeout: float, + config: SequenceManagerConfig, metadata: Optional[bytes] = None, **kwargs, ) -> Sequence[torch.Tensor]: @@ -153,5 +154,5 @@ async def run_remote_backward( size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) + deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) return deserialized_grad_inputs diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index c980412..a7a0f1d 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -40,6 +40,7 @@ class SequenceManagerConfig: allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers use_server_to_server: bool = True # Use direct server-to-server communication + connect_timeout: float = 5 # timeout for opening a connection request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests update_period: float = 60 # refresh DHT information once in this many seconds diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 425fdb7..ebc56b4 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -76,7 +76,7 @@ async def sequential_forward( stub, sequence_manager.rpc_info, *inputs_and_prompts, - timeout=sequence_manager.config.request_timeout, + config=sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata), ) @@ -161,7 +161,7 @@ async def sequential_backward( inputs, grad_outputs, prompts[span.start : span.end], - timeout=sequence_manager.config.request_timeout, + config=sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata), ) grad_outputs = [grad_outputs] From 6a1b8a6a9066f7ffc0b8b24c5e357ac1878dea8b Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 31 Jul 2023 01:23:56 +0200 Subject: [PATCH 152/168] Add Stable Beluga 2 to readme (#424) --- README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 0c44ff6..982a1b5 100644 --- a/README.md +++ b/README.md @@ -8,15 +8,15 @@

-Generate text with distributed [LLaMA 2](https://ai.meta.com/llama/) ([70B](https://huggingface.co/meta-llama/Llama-2-70b-hf), [70B-Chat](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: +Generate text with distributed [LLaMA 2 (70B)](https://huggingface.co/meta-llama/Llama-2-70b-hf), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: ```python from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM -model_name = "enoch/llama-65b-hf" +model_name = "stabilityai/StableBeluga2" # You can also use "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-70b-chat-hf", -# "bigscience/bloom", or "bigscience/bloomz" +# repos with LLaMA-65B, "bigscience/bloom", or "bigscience/bloomz" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoDistributedModelForCausalLM.from_pretrained(model_name) @@ -33,7 +33,7 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... 🦙 **Want to run LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev). -📋 **Terms of use.** Make sure you follow the model license (see the ones for [LLaMA 2](https://bit.ly/llama2-license), [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license)). +📋 **Terms of use.** Make sure you follow the model license (see [LLaMA 2](https://bit.ly/llama2-license), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2/blob/main/LICENSE.txt), [LLaMA](https://bit.ly/llama-license), and [BLOOM](https://bit.ly/bloom-license)). 🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. @@ -48,7 +48,7 @@ Petals is a community-run system — we rely on people sharing their GPUs. Y ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install git+https://github.com/bigscience-workshop/petals -python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b +python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16 ``` 🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows). @@ -57,12 +57,10 @@ python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanac ```bash sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \ - python -m petals.cli.run_server --port 31330 enoch/llama-65b-hf --adapters timdettmers/guanaco-65b + python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 --torch_dtype float16 ``` -These commands host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. - -💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! +These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. 🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then use this command for `petals.cli.run_server`: @@ -70,6 +68,8 @@ These commands host a part of LLaMA-65B with optional [Guanaco](https://huggingf python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE ``` +💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! + 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). 🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`. From 351e96bc469b260c86d30cd823a262a0b71be66e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 3 Aug 2023 02:00:43 +0200 Subject: [PATCH 153/168] Penalize servers that use relays during rebalancing (#428) Servers accessible only via relays may introduce issues if they are the only type of servers holding certain blocks. Specifically, a connection to such servers may be unstable or opened after a certain delay. This PR changes their self-reported throughput, so that the rebalancing algorithm prefers to put directly available servers for hosting each block. --- src/petals/client/routing/sequence_manager.py | 12 ++---------- src/petals/server/server.py | 15 ++++++++------- src/petals/server/throughput.py | 7 ++++++- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index a7a0f1d..b19d468 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -292,9 +292,7 @@ class RemoteSequenceManager: # This is okay since false positives are more costly than false negatives here. return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left - def _make_sequence_with_max_throughput( - self, start_index: int, end_index: int, *, relay_penalty: float = 0.5 - ) -> List[RemoteSpanInfo]: + def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]: span_sequence = [] current_index = start_index while current_index < end_index: @@ -302,13 +300,7 @@ class RemoteSequenceManager: if not candidate_spans: raise MissingBlocksError(current_index) - span_weights = np.array( - [ - span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty) - for span in candidate_spans - ], - dtype=np.float64, - ) + span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end diff --git a/src/petals/server/server.py b/src/petals/server/server.py index fec6e82..5c47270 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -83,7 +83,7 @@ class Server: quant_type: Optional[QuantType] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, - dht_client_mode: Optional[bool] = None, + reachable_via_relay: Optional[bool] = None, use_relay: bool = True, use_auto_relay: bool = True, adapters: Sequence[str] = (), @@ -129,20 +129,20 @@ class Server: for block_index in range(self.block_config.num_hidden_layers) ] - if dht_client_mode is None: + if reachable_via_relay is None: is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs) - dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer - logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}") + reachable_via_relay = is_reachable is False # if can't check reachability (returns None), run a full peer + logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}") self.dht = DHT( initial_peers=initial_peers, start=True, num_workers=self.block_config.num_hidden_layers, use_relay=use_relay, use_auto_relay=use_auto_relay, - client_mode=dht_client_mode, + client_mode=reachable_via_relay, **kwargs, ) - self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None + self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()] if initial_peers == PUBLIC_INITIAL_PEERS: @@ -227,6 +227,7 @@ class Server: num_blocks=num_blocks, quant_type=quant_type, tensor_parallel_devices=self.tensor_parallel_devices, + reachable_via_relay=reachable_via_relay, force_eval=(throughput == "eval"), cache_dir=cache_dir, ) @@ -239,7 +240,7 @@ class Server: adapters=tuple(adapters), torch_dtype=str(torch_dtype).replace("torch.", ""), quant_type=quant_type.name.lower(), - using_relay=self.dht.client_mode, + using_relay=reachable_via_relay, **throughput_info, ) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index fbea3d2..d977611 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -41,6 +41,8 @@ def get_server_throughput( num_blocks: int, quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], + reachable_via_relay: bool, + relay_penalty: float = 0.2, force_eval: bool = False, cache_dir: Optional[str] = None, ) -> Dict[str, float]: @@ -94,7 +96,10 @@ def get_server_throughput( # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2 average_blocks_used = (num_blocks + 1) / 2 throughput = throughput_info["forward_rps"] / average_blocks_used - throughput = min(throughput, throughput_info.get("network_rps", math.inf)) + + network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1) + throughput = min(throughput, network_rps) + throughput_info["throughput"] = throughput logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks") From a1f7791d5e52c49d255fee4e571db49c20e0c8c1 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 3 Aug 2023 02:17:07 +0200 Subject: [PATCH 154/168] Fix petals.utils.ping for servers with client-mode DHT (#430) Fix #429. --- src/petals/__init__.py | 2 +- src/petals/utils/ping.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index d2f91bc..c9f8223 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.0.1" +__version__ = "2.0.1.post1" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/utils/ping.py b/src/petals/utils/ping.py index 4245bf4..5e3e775 100644 --- a/src/petals/utils/ping.py +++ b/src/petals/utils/ping.py @@ -24,7 +24,10 @@ async def ping( start_time = time.perf_counter() await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout) return time.perf_counter() - start_time - except Exception: + except Exception as e: + if str(e) == "protocol not supported": # Happens on servers with client-mode DHT (e.g., reachable via relays) + return time.perf_counter() - start_time + logger.debug(f"Failed to ping {peer_id}:", exc_info=True) return math.inf From d0b5af34cd9fdeaccdecd5692dabdb030fd38c56 Mon Sep 17 00:00:00 2001 From: Vadim Peretokin Date: Sun, 6 Aug 2023 14:47:21 +0200 Subject: [PATCH 155/168] Fix typo and make blocks message more informative (#437) The message really doesn't tell me much as a user, since I never touched update_period to begin with: ``` Aug 06 09:43:07.287 [WARN] [petals.server.server.run:701] Declaring blocs to DHT takes more than --update_period, consider increasing it ``` Made it better and more informative. --- src/petals/server/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 5c47270..7772fa6 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -698,7 +698,9 @@ class ModuleAnnouncerThread(threading.Thread): delay = self.update_period - (time.perf_counter() - start_time) if delay < 0: - logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it") + logger.warning( + f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})" + ) self.trigger.wait(max(delay, 0)) self.trigger.clear() From 679397df0c8fd374084eaa4b55b63e7279600053 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 6 Aug 2023 17:11:49 +0400 Subject: [PATCH 156/168] Update Discord links from channels to forums (#440) As our Discord community growths, we found it difficult to look for open and resolved issues in **#running-a-client** and **#running-a-server** channels, as well as navigate through interleaving conversations happening there. That's why we recreated these channels as Discord forums, where different discussions are separated into different posts. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 982a1b5..f042051 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... 🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. -💬 **Any questions?** Ping us in [our Discord](https://discord.gg/J29mCBNBvm)! +💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)! ### Connect your GPU and increase Petals capacity @@ -68,7 +68,7 @@ These commands will host a part of [Stable Beluga 2](https://huggingface.co/stab python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE ``` -💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/D9MwApKgWa)! +💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)! 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). From b58141ef667534a4db4d3aa6905164484361d438 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 6 Aug 2023 18:55:22 +0400 Subject: [PATCH 157/168] Remove distracting links from readme (#441) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f042051..ea7919a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@

-Generate text with distributed [LLaMA 2 (70B)](https://huggingface.co/meta-llama/Llama-2-70b-hf), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: +Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: ```python from transformers import AutoTokenizer @@ -96,8 +96,8 @@ Learning more: ## How does it work? -- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning. -- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. +- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning. +- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.

From 32fbab5192da5ee26ee2a0a7eade6d101690a70d Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 02:22:21 +0400 Subject: [PATCH 158/168] Remove deprecated comment in fine-tuning notebook (#443) --- examples/prompt-tuning-sst2.ipynb | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 9123c1a..b6f2d8a 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -92,9 +92,6 @@ }, "outputs": [], "source": [ - "# Choose a model you'd like to prompt-tune. We recommend starting with\n", - "# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n", - "# The code below uses LLaMA-65B.\n", "MODEL_NAME = \"enoch/llama-65b-hf\"\n", "\n", "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", From 593d980ad8676fa0b43d8cc266b0dbb2052c31ec Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 02:33:42 +0400 Subject: [PATCH 159/168] Use bitsandbytes 0.41.1 (#442) --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 7c04686..1560f2b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ packages = find: python_requires = >=3.8 install_requires = torch>=1.12 - bitsandbytes==0.40.1.post1 + bitsandbytes==0.41.1 accelerate>=0.20.3,<0.21.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 From ac9b5467067735a885f7a4dfad689a2a2dc7f594 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 7 Aug 2023 14:32:51 +0300 Subject: [PATCH 160/168] [Refactor] extract block forward, backward and inference into a separate file (#435) This PR does not change any functionality. It merely moves stuff around. List of changes: handler.py/_rpc_forward became block_methods/rpc_forward handler.py/_rpc_backward became block_methods/rpc_backward the math bits of rpc_inference were extracted into block_methods/iterate_rpc_inference --------- Co-authored-by: Your Name Co-authored-by: artek0chumak Co-authored-by: Aleksandr Borzunov --- src/petals/server/block_functions.py | 195 +++++++++++++++++++++++++++ src/petals/server/handler.py | 192 +++----------------------- 2 files changed, 214 insertions(+), 173 deletions(-) create mode 100644 src/petals/server/block_functions.py diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py new file mode 100644 index 0000000..9208deb --- /dev/null +++ b/src/petals/server/block_functions.py @@ -0,0 +1,195 @@ +""" +This module implements server-side computations on served blocks: forward, backward and inference; used by handler +""" +from __future__ import annotations + +from typing import AsyncIterator, Optional, Sequence, Tuple, Union + +import torch +from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor +from hivemind.moe.expert_uid import ExpertUID +from hivemind.proto import runtime_pb2 +from hivemind.utils.nested import nested_flatten + +from petals.data_structures import InferenceMetadata +from petals.server.backend import TransformerBackend +from petals.server.memory_cache import Handle +from petals.server.task_pool import PrioritizedTaskPool +from petals.server.task_prioritizer import TaskPrioritizerBase +from petals.utils.misc import DUMMY, is_dummy + + +async def run_rpc_forward( + *flat_tensors: torch.Tensor, + requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", + prioritizer: TaskPrioritizerBase, + points: int = 0, +) -> torch.Tensor: + """ + Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream + + :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors + :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) + :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass + :returns: hidden states after the last layer [batch_size, seq_length, hid_size] + """ + hidden_states, prompts = flat_tensors + dtype = requested_backends[0].dtype + # check parse input tensors and cast dtypes + hidden_states = hidden_states.to(dtype) + assert hidden_states.ndim == 3 + if prompts is None or is_dummy(prompts): + prompts = [DUMMY] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + + # Run a chain of requested backends + for backend, prompt in zip(requested_backends, prompts): + if not is_dummy(prompt): + hidden_states[:, : prompt.shape[1]] += prompt + + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + hidden_states, points=points / len(requested_backends), backend=backend, type="forward" + ) + (hidden_states,) = await backend.forward_pool.submit_task( + hidden_states, + active_adapter, + priority=priority, + ) + assert isinstance(hidden_states, torch.Tensor) + assert ( + hidden_states.ndim == 3 + ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" + + return hidden_states + + +async def run_rpc_backward( + *flat_tensors: torch.Tensor, + requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", + prioritizer: TaskPrioritizerBase, + points: int = 0, +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: + inputs, grad_outputs, prompts = flat_tensors + # Cast inputs & grad outputs to backend dtype + inputs = inputs.to(requested_backends[0].dtype) + grad_outputs = grad_outputs.to(requested_backends[-1].dtype) + + if prompts is None or is_dummy(prompts): + prompts = [DUMMY] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + + # Run a forward chain to collect intermediate inputs + # Note that we do not forward for the last module since we do not need its output + inter_inputs = [] + for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): + assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" + if not is_dummy(prompt): + inputs[:, : prompt.shape[1]] += prompt + inter_inputs.append(inputs) + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" + ) + (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) + + assert isinstance(inputs, torch.Tensor) + + if not is_dummy(prompts[-1]): + inputs[:, : prompts[-1].shape[1]] += prompts[-1] + inter_inputs.append(inputs) + + assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" + grad_prompts_reversed = [] + # Run a chain of requested backends + for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" + ) + (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) + + assert isinstance(grad_outputs, torch.Tensor) + if not is_dummy(prompt): + grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) + + grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY + return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape + + +async def iterate_rpc_inference( + requested_uids: Sequence[ExpertUID], + requested_backends: Sequence[TransformerBackend], + active_adapter: Optional[str], + input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]], + cache_handles: Sequence[Sequence[Handle]], + max_length: int, + prioritizer: TaskPrioritizerBase, + points: int, +) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]: + assert len(cache_handles) == len(requested_backends) + + prefix_length = 0 + point_per_piece = points / max_length if max_length > 0 else 0.0 + + async for request, step_metadata in input_iterator: + hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) + + # Cast inputs to backend dtype + hidden_states = hidden_states.to(requested_backends[0].dtype) + assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" + + # parse deep prompts (optional argument) + has_prompts = prompts is not None and not is_dummy(prompts) + if not has_prompts: + prompts = [None] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] + + if not (len(requested_backends) == len(prompts)): + raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") + + length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) + if prefix_length + length_increment > max_length: + raise ValueError( + f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" + f" exceeds pre-allocated maximum {max_length}" + ) + + priority = prioritizer.prioritize( + hidden_states, + hypo_ids, + points=point_per_piece, + requested_uids=requested_uids, + type="inference", + ) + + inference_infos = tuple( + InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) + for uid, handles in zip(requested_uids, cache_handles) + ) + + if hidden_states.numel() == 0: + pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. + # when user wants to pre-allocate cache or check that server *can* allocate that cache + else: + assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" + (hidden_states,) = await requested_backends[0].inference_pool.submit_task( + hidden_states, hypo_ids, inference_infos, *prompts, priority=priority + ) + + # serialize and send last layer outputs + output_tensors = [ + serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) + for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)) + ] + can_push = not has_prompts + yield output_tensors, can_push + + # prepare for next step + prefix_length += length_increment diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index d3776de..b9be294 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -6,7 +6,7 @@ import multiprocessing as mp import sys from enum import Enum from itertools import chain -from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple import torch from async_timeout import timeout @@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming import petals -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID from petals.server.backend import TransformerBackend +from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward from petals.server.memory_cache import Handle -from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase -from petals.utils.misc import DUMMY, is_dummy logger = get_logger(__name__) @@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") - active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") if not requested_uids: @@ -163,78 +161,28 @@ class TransformerConnectionHandler(ConnectionHandler): f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}" ) - point_per_piece = points / max_length if max_length > 0 else 0.0 batch_size = request.tensors[0].size[0] if request.tensors else 1 - prefix_length = 0 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles: - assert len(cache_handles) == len(requested_backends) - first_request = request background_tasks = set() - async for request, metadata in self._iterate_inference_steps( - first_request, requests, session_id, requested_uids, context + async for output_tensors, can_push in iterate_rpc_inference( + requested_uids=requested_uids, + requested_backends=requested_backends, + active_adapter=self._get_active_adapter(metadata), + input_iterator=self._iterate_inference_steps( + request, requests, session_id, requested_uids, context + ), + cache_handles=cache_handles, + max_length=max_length, + prioritizer=self._prioritizer, + points=points, ): - hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) - - # Cast inputs to backend dtype - hidden_states = hidden_states.to(requested_backends[0].dtype) - assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" - - # parse deep prompts (optional argument) - has_prompts = prompts is not None and not is_dummy(prompts) - if not has_prompts: - prompts = [None] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] - prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] - - if not (len(requested_backends) == len(prompts)): - raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") - - length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) - if prefix_length + length_increment > max_length: - raise ValueError( - f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" - f" exceeds pre-allocated maximum {max_length}" - ) - - priority = self._prioritizer.prioritize( - hidden_states, - hypo_ids, - points=point_per_piece, - requested_uids=requested_uids, - type="inference", - ) - - inference_infos = tuple( - InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) - for uid, handles in zip(requested_uids, cache_handles) - ) - - if hidden_states.numel() == 0: - pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. - # when user wants to pre-allocate cache or check that server *can* allocate that cache - else: - assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" - (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, *prompts, priority=priority - ) - - # serialize and send last layer outputs - output_tensors = [ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip( - (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) - ) - ] - if not has_prompts: + if can_push: task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) background_tasks.add(task) # Keep reference until it is done to save it from GC task.add_done_callback(background_tasks.discard) yield runtime_pb2.ExpertResponse(tensors=output_tensors) - # prepare for next step - prefix_length += length_increment finally: self._log_request("rpc_inference.close", requested_uids, context) @@ -408,7 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" - hidden_states = await _rpc_forward( + hidden_states = await run_rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -435,7 +383,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_forward_stream should have number of points as number or None, got {points}" - hidden_states = await _rpc_forward( + hidden_states = await run_rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -486,7 +434,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" - grads = await _rpc_backward( + grads = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -511,7 +459,7 @@ class TransformerConnectionHandler(ConnectionHandler): points, (float, int) ), f"rpc_backward_stream should have number of points as number or None, got {points}" - grads = await _rpc_backward( + grads = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, @@ -621,105 +569,3 @@ class TransformerConnectionHandler(ConnectionHandler): result.update(block_info) return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result)) - - -async def _rpc_forward( - *flat_tensors: torch.Tensor, - requested_backends: Sequence[TransformerBackend], - active_adapter: str = "", - prioritizer: TaskPrioritizerBase, - points: int = 0, -) -> torch.Tensor: - """ - Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream - - :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors - :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) - :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass - :returns: hidden states after the last layer [batch_size, seq_length, hid_size] - """ - hidden_states, prompts = flat_tensors - dtype = requested_backends[0].dtype - # check parse input tensors and cast dtypes - hidden_states = hidden_states.to(dtype) - assert hidden_states.ndim == 3 - if prompts is None or is_dummy(prompts): - prompts = [DUMMY] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] - - # Run a chain of requested backends - for backend, prompt in zip(requested_backends, prompts): - if not is_dummy(prompt): - hidden_states[:, : prompt.shape[1]] += prompt - - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - hidden_states, points=points / len(requested_backends), backend=backend, type="forward" - ) - (hidden_states,) = await backend.forward_pool.submit_task( - hidden_states, - active_adapter, - priority=priority, - ) - assert isinstance(hidden_states, torch.Tensor) - assert ( - hidden_states.ndim == 3 - ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" - - return hidden_states - - -async def _rpc_backward( - *flat_tensors: torch.Tensor, - requested_backends: Sequence[TransformerBackend], - active_adapter: str = "", - prioritizer: TaskPrioritizerBase, - points: int = 0, -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - inputs, grad_outputs, prompts = flat_tensors - # Cast inputs & grad outputs to backend dtype - inputs = inputs.to(requested_backends[0].dtype) - grad_outputs = grad_outputs.to(requested_backends[-1].dtype) - - if prompts is None or is_dummy(prompts): - prompts = [DUMMY] * len(requested_backends) - else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] - - # Run a forward chain to collect intermediate inputs - # Note that we do not forward for the last module since we do not need its output - inter_inputs = [] - for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): - assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" - if not is_dummy(prompt): - inputs[:, : prompt.shape[1]] += prompt - inter_inputs.append(inputs) - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" - ) - (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) - - assert isinstance(inputs, torch.Tensor) - - if not is_dummy(prompts[-1]): - inputs[:, : prompts[-1].shape[1]] += prompts[-1] - inter_inputs.append(inputs) - - assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" - grad_prompts_reversed = [] - # Run a chain of requested backends - for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): - assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" - priority = prioritizer.prioritize( - inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" - ) - (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) - - assert isinstance(grad_outputs, torch.Tensor) - if not is_dummy(prompt): - grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) - - grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY - return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape From 00d48dcbe12ec26a8fd8aaab37030f994ed84555 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 19:47:22 +0400 Subject: [PATCH 161/168] Override float32 in config to bfloat16 (#431) --- README.md | 4 ++-- src/petals/server/block_utils.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ea7919a..aa93a43 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Petals is a community-run system — we rely on people sharing their GPUs. Y ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install git+https://github.com/bigscience-workshop/petals -python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16 +python -m petals.cli.run_server stabilityai/StableBeluga2 ``` 🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows). @@ -57,7 +57,7 @@ python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16 ```bash sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \ - python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 --torch_dtype float16 + python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 ``` These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures. diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index eb5300e..effce82 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -11,7 +11,8 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype] """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise.""" if dtype not in ("auto", None): return dtype - if config.torch_dtype not in ("auto", None): + if config.torch_dtype not in ("auto", None, torch.float32): + # If config specifies float32, we override it to the default dtype below return config.torch_dtype return torch.bfloat16 From 2a150770a47d2ee8fd9f9bf129a54d1afdf29e64 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 21:43:21 +0400 Subject: [PATCH 162/168] Prefer longer servers for fine-tuning, exclude unreachable (#448) We choose longer servers to minimize the number of hops but leave some randomization to distribute the load. We also exclude servers known to be unreachable. --- src/petals/client/routing/sequence_manager.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index b19d468..7328cdc 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -50,7 +50,7 @@ class SequenceManagerConfig: ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo) - max_pinged: int = 5 # max servers to ping from each sequence side, per update + max_pinged: int = 3 # max servers to ping from each sequence side, per update ping_timeout: float = 2 # max time to wait for pings, per update @@ -293,6 +293,8 @@ class RemoteSequenceManager: return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]: + client_server_rtts = self.ping_aggregator.to_dict() + span_sequence = [] current_index = start_index while current_index < end_index: @@ -300,7 +302,13 @@ class RemoteSequenceManager: if not candidate_spans: raise MissingBlocksError(current_index) - span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) + # We choose longer servers to minimize the number of hops but leave some randomization + # to distribute the load. We also exclude servers known to be unreachable. + eps = 1e-6 + span_weights = np.array( + [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans], + dtype=np.float64, + ) chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end @@ -361,9 +369,13 @@ class RemoteSequenceManager: self.state.sequence_info.update_(new_block_infos) first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]] + middle_servers = [ + span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans + ] last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]] pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged)) + pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged)) pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged)) self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout) From df6fdd2d0b2c138316d40a519b9ef6078fe04a20 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 8 Aug 2023 04:59:55 +0400 Subject: [PATCH 163/168] Force using --new_swarm instead of empty --initial_peers (#451) This prohibits passing `--initial_peers` without arguments, since it's likely to be a side-effect from `--initial_peers $INITIAL_PEERS` with the env var not set. Users should use `--new_swarm` for that, as explained in the private swarm tutorial. --- src/petals/cli/run_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index d4fc999..c82ff44 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -122,7 +122,7 @@ def main(): help="Timeout (in seconds) for waiting the next step's inputs inside an inference session") group = parser.add_mutually_exclusive_group() - group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS, + group.add_argument('--initial_peers', type=str, nargs='+', required=False, default=PUBLIC_INITIAL_PEERS, help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm') group.add_argument('--new_swarm', action='store_true', help='Start a new private swarm (i.e., do not connect to any initial peers)') From 8c546d988a4205d440a2625523f0a6f69ee24e1c Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 8 Aug 2023 19:10:27 +0400 Subject: [PATCH 164/168] Test Llama, rebalancing, throughput eval, and all CLI scripts (#452) This PR extends CI to: 1. Test Llama code using [TinyLlama-v0](https://huggingface.co/Maykeye/TinyLLama-v0). 2. Test rebalancing (sets up a situation where the 1st server needs to change its original position). 3. Check if benchmark scripts run (in case someone breaks its code). Note that the benchmark results are meaningless here (since they're measured on a tiny swarm of CPU servers, with low `--n_steps`). 4. Test `petals.cli.run_dht`. 5. Increase swap space and watch free RAM (a common issue is that actions are cancelled without explanation if there's not enough RAM - so it's a useful reminder + debug tool). 6. Fix flapping tests for bloom-560m by increasing tolerance. Other minor changes: fix `--help` messages to show defaults, fix docs, tune rebalancing constants. --- .github/workflows/run-tests.yaml | 93 ++++++++++++++++++++++-------- benchmarks/benchmark_forward.py | 18 +++--- benchmarks/benchmark_inference.py | 14 ++--- benchmarks/benchmark_training.py | 24 ++++---- src/petals/cli/run_dht.py | 8 ++- src/petals/cli/run_server.py | 2 +- src/petals/server/server.py | 6 +- tests/{test.id => bootstrap.id} | Bin tests/server2.id | Bin 0 -> 1197 bytes tests/test_aux_functions.py | 3 + tests/test_block_exact_match.py | 4 +- tests/test_chained_calls.py | 6 +- tests/test_full_model.py | 57 +++++++++--------- tests/test_remote_sequential.py | 12 ++-- tests/test_sequence_manager.py | 4 +- tests/test_server_stats.py | 8 ++- tests/test_tensor_parallel.py | 5 +- 17 files changed, 161 insertions(+), 103 deletions(-) rename tests/{test.id => bootstrap.id} (100%) create mode 100644 tests/server2.id diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 3bccda3..735fd2a 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -10,10 +10,20 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.8', '3.9', '3.10', '3.11' ] + include: + - { model: 'bigscience/bloom-560m', python-version: '3.8' } + - { model: 'bigscience/bloom-560m', python-version: '3.9' } + - { model: 'bigscience/bloom-560m', python-version: '3.10' } + - { model: 'bigscience/bloom-560m', python-version: '3.11' } + - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' } + - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' } fail-fast: false timeout-minutes: 15 steps: + - name: Increase swap space + uses: pierotofy/set-swap-space@master + with: + swap-size-gb: 10 - name: Checkout uses: actions/checkout@v3 - name: Set up Python @@ -31,44 +41,77 @@ jobs: pip install .[dev] - name: Test run: | - export MODEL_NAME=bigscience/bloom-560m - export REF_NAME=bigscience/bloom-560m - export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft - - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ - --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \ - --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \ - --adapters $ADAPTER_NAME &> server1.log & - SERVER1_PID=$! + export MODEL_NAME="${{ matrix.model }}" + export REF_NAME="${{ matrix.model }}" + export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}" + export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}" + + # [Step 1] Watch free RAM (lack of RAM is a common issue in CI) + + bash -c 'while true; do free -h && sleep 30s; done' & + RAM_WATCH_PID=$! - sleep 5 # wait for the first server to initialize DHT + # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) + + python -m petals.cli.run_dht \ + --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log & + BOOTSTRAP_PID=$! export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g - # ^-- server 1 multiaddr is determined by --identity and --host_maddrs + # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log & - SERVER2_PID=$! + sleep 5 # wait for DHT init + + python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \ + --mean_balance_check_period 10 \ + --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log & + SERVER1_PID=$! + # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there - sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve + sleep 10 # wait for the 1st server to choose blocks - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log & + python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \ + --identity_path tests/server2.id \ + --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log & + SERVER2_PID=$! + + python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \ + --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \ + --initial_peers $INITIAL_PEERS --throughput auto &> server3.log & SERVER3_PID=$! + # ^-- chunking test - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log & + python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \ + --initial_peers $INITIAL_PEERS --throughput auto &> server4.log & SERVER4_PID=$! + # ^-- tensor parallelism test (not compatible with adapters yet) - tail -n 100 -f server*.log & + sleep 5 # wait for the log files to appear + + tail -n 100 -f bootstrap.log server*.log & LOGGER_PID=$! - sleep 30 # wait for servers to download layers - kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init + sleep 30 # wait for servers to eval throughput, download layers, and rebalance + kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init + + # [Step 3] Run PyTest pytest tests --durations=0 --durations-min=1.0 -v - kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests + # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers) + + python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ + --seq_len 3 + python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ + --seq_len 3 --batch_size 3 --n_steps 1 + python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ + --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls + python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ + --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm + + # [Step 5] Clean up + + kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests - kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID + kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID echo "Done!" diff --git a/benchmarks/benchmark_forward.py b/benchmarks/benchmark_forward.py index e95c5ec..bf547ec 100755 --- a/benchmarks/benchmark_forward.py +++ b/benchmarks/benchmark_forward.py @@ -15,15 +15,15 @@ logger = get_logger() def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="bigscience/bloom") - parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS) - parser.add_argument("--torch_dtype", type=str, default="bfloat16") - parser.add_argument("--n_processes", type=str, default=1) - parser.add_argument("--seq_len", type=int, default=128) - parser.add_argument("--n_steps", type=int, default=100) - parser.add_argument("--batch_size", type=int, required=True) - parser.add_argument("--warmup_steps", type=int, default=1) + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--model", type=str, required=True, help="Model") + parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") + parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype") + parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") + parser.add_argument("--seq_len", type=int, default=128, help="Sequence length") + parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps") + parser.add_argument("--batch_size", type=int, required=True, help="Batch size") + parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps") args = parser.parse_args() if args.n_processes == "n_gpus": diff --git a/benchmarks/benchmark_inference.py b/benchmarks/benchmark_inference.py index 607ff88..e894bb1 100755 --- a/benchmarks/benchmark_inference.py +++ b/benchmarks/benchmark_inference.py @@ -16,13 +16,13 @@ logger = get_logger() def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="bigscience/bloom") - parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS) - parser.add_argument("--torch_dtype", type=str, default="bfloat16") - parser.add_argument("--n_processes", type=str, default=1) - parser.add_argument("--seq_len", type=int, default=2048) - parser.add_argument("--warmup_steps", type=int, default=1) + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--model", type=str, required=True, help="Model") + parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") + parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype") + parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") + parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length") + parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps") args = parser.parse_args() if args.n_processes == "n_gpus": diff --git a/benchmarks/benchmark_training.py b/benchmarks/benchmark_training.py index 0853dfc..85061a3 100755 --- a/benchmarks/benchmark_training.py +++ b/benchmarks/benchmark_training.py @@ -15,18 +15,18 @@ logger = get_logger() def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="bigscience/bloom") - parser.add_argument("--device", type=str, default="cpu") - parser.add_argument("--task", type=str, default="cls") - parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS) - parser.add_argument("--torch_dtype", type=str, default="bfloat16") - parser.add_argument("--n_processes", type=str, default=1) - parser.add_argument("--seq_len", type=int, default=128) - parser.add_argument("--pre_seq_len", type=int, default=16) - parser.add_argument("--n_steps", type=int, default=10) - parser.add_argument("--batch_size", type=int, required=True) - parser.add_argument("--warmup_steps", type=int, default=1) + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--model", type=str, required=True, help="Model") + parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client") + parser.add_argument("--task", type=str, default="cls", help="Training task type") + parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") + parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype") + parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") + parser.add_argument("--seq_len", type=int, default=128, help="Sequence length") + parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens") + parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps") + parser.add_argument("--batch_size", type=int, required=True, help="Batch size") + parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps") args = parser.parse_args() assert args.task in ["cls", "causal_lm"] diff --git a/src/petals/cli/run_dht.py b/src/petals/cli/run_dht.py index 2f30516..777d9d0 100644 --- a/src/petals/cli/run_dht.py +++ b/src/petals/cli/run_dht.py @@ -7,8 +7,8 @@ This script may be used for launching lightweight CPU machines serving as bootst This may be eventually merged to the hivemind upstream. """ +import argparse import time -from argparse import ArgumentParser from secrets import token_hex from hivemind.dht import DHT, DHTNode @@ -35,7 +35,7 @@ async def report_status(dht: DHT, node: DHTNode): def main(): - parser = ArgumentParser() + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--initial_peers", nargs="*", @@ -73,7 +73,9 @@ def main(): help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)", ) parser.add_argument( - "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls" + "--use_auto_relay", + action="store_true", + help="Look for libp2p relays to become reachable if we are behind NAT/firewall", ) parser.add_argument( "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT" diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index c82ff44..d85c8ac 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -158,7 +158,7 @@ def main(): "when connecting to the public swarm. If you connect to a private swarm, " "the check is skipped by default. Use this option only if you know what you are doing") - parser.add_argument("--adapters", nargs='+', default=(), + parser.add_argument("--adapters", nargs='*', default=(), help="List of pre-loaded LoRA adapters that can be used for inference or training") # fmt:on diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 7772fa6..bf7470a 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -78,7 +78,7 @@ class Server: sender_threads: int = 1, balance_quality: float = 0.75, mean_balance_check_period: float = 120, - mean_block_selection_delay: float = 2.5, + mean_block_selection_delay: float = 5, token: Optional[Union[str, bool]] = None, quant_type: Optional[QuantType] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, @@ -201,6 +201,8 @@ class Server: assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" if num_blocks is None and block_indices is None: num_blocks = self._choose_num_blocks() + if num_blocks is not None: + num_blocks = min(num_blocks, self.block_config.num_hidden_layers) if block_indices is not None: try: first_block_index, last_block_index = block_indices.split(":") @@ -295,7 +297,7 @@ class Server: num_blocks = min(num_blocks, self.block_config.num_hidden_layers) logger.info( - f"Server will fill all your GPU memory with {num_blocks} transformer blocks. " + f"Server will fill your GPU memory with {num_blocks} transformer blocks. " f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually" ) return num_blocks diff --git a/tests/test.id b/tests/bootstrap.id similarity index 100% rename from tests/test.id rename to tests/bootstrap.id diff --git a/tests/server2.id b/tests/server2.id new file mode 100644 index 0000000000000000000000000000000000000000..261555779ed4194b8a3cfb49a7759755f13b270f GIT binary patch literal 1197 zcmV;e1XBA501~JPFoFc60s#O5f&l>l*q);s;IqHH@*TeGz3p8h%ogpB@p~dfJljpA zmte=Ie-Gw8J^pzE;N;A|TySjx1ITHq&Fw2mHDu=LQ2lYRrY>DWrj4QP<2_0oz!S0g z@)w-t(L+;5N{@a1yYto5*dj`NLP!=o%-#^^PAEF0xtrMu4o6pYD^O7tSrS^2Y=S?b z&bw`eDc91oG@pLQ7mY*9HXPSlvGz~uU+CX0ot^Qp`&O>>Z7IQt%Z;o5=UtuvOr?t= z{q_OT^!9R#u?@xC_CRk!lArRsJn+Zu*k#>yxP|dgnrZX3@2(eUjw}Vy!~PsWNai@M zFE6W)xe}tvWxpx%f!JEtX3v{r0s{d60Rn;n0RXKXf^O;&+`i1Cw+q2qYu`ReXp3vg zBPDa0GoZm0s2P;=2J2?hU#U)}#%w4HD*q%12Qamk8UETGupi+4>vms2$u$w&nggMWk z`irkmy|(rs--zp8M;QL=+{)g6n)UQEATa`gfdKwc-hnoGwdV@~_{9}<5=y;u@eOicSvo%SWl}CRVNK@W zn!y03xCc>(6p@G6r94hYn_D2p9ZIR1AS6xTjxgbu-teQuBfVq#kxkN#5m-|W>I(bU zSc?LIfdJYK-1@S;1xE0$a8+)C7ho~FSe5p`prKwzXe9u06xr=CZP>H*7igmqfuv!a zMR|e0l*4N1>0`FPLft$KHMqLrLnnVTT&CS`l0E6u^hk1(_JNN{6pEjOUAh zx45S~IIH=|Y51!r_PZkwx8SC!?7G$}q4eLiKHJqSP(ke?YQ2kQahF*b_zeM>68_+z zUBpQZ1JHooP4O*Et0bYFFHq#MGvx9t7#V%nBkQO@CPax+IL=r+ z1WD4qu>ygB3(`)Nq}2?OF#$9K-LBo&kUFUC&}Y{Th>!Gqbq=~js!&4INySpskZPrk zq~R3*CSsk94xzj4bp79fApKHVMu2JLokz-%9abz$sl%NlhK~+hDRH6M00Nhd<&%bg z5t-KM&saaR*Hs2oAHq)S_KCV(;3cI<)V3lCxX{Xp0)c@5#3qS(2@Ge+yD|^y#pT(_ zHUPiXqU^@iKc0jJ(|7+UE8$9Ujgl7jofv8!K3%fbpE1yr47Mzvdk Date: Wed, 9 Aug 2023 16:50:02 +0400 Subject: [PATCH 165/168] benchmarks: Aggregate speed among workers, set default dtype torch32 (#454) --- benchmarks/benchmark_forward.py | 12 ++++++++---- benchmarks/benchmark_inference.py | 12 ++++++++---- benchmarks/benchmark_training.py | 12 ++++++++---- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/benchmarks/benchmark_forward.py b/benchmarks/benchmark_forward.py index bf547ec..5ad678c 100755 --- a/benchmarks/benchmark_forward.py +++ b/benchmarks/benchmark_forward.py @@ -18,7 +18,7 @@ def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--model", type=str, required=True, help="Model") parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") - parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype") + parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype") parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") parser.add_argument("--seq_len", type=int, default=128, help="Sequence length") parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps") @@ -31,15 +31,19 @@ def main(): else: args.n_processes = int(args.n_processes) - processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)] + pipe_recv, pipe_send = mp.Pipe(duplex=False) + processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)] for proc in processes: proc.start() for proc in processes: proc.join() + speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)]) + logger.info(f"Final result: {speed=:.2f}") + @torch.inference_mode() -def benchmark_forward(process_idx, args): +def benchmark_forward(process_idx, args, result_pipe): model = AutoDistributedModel.from_pretrained( args.model, initial_peers=args.initial_peers, @@ -64,7 +68,7 @@ def benchmark_forward(process_idx, args): speed = input_ids.numel() / np.mean(step_times) logger.info(f"{process_idx=} {step=} {speed=:.2f}") - logger.info(f"Final result: {process_idx=} {speed=:.2f}") + result_pipe.send(speed) if __name__ == "__main__": diff --git a/benchmarks/benchmark_inference.py b/benchmarks/benchmark_inference.py index e894bb1..202dc6d 100755 --- a/benchmarks/benchmark_inference.py +++ b/benchmarks/benchmark_inference.py @@ -19,7 +19,7 @@ def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--model", type=str, required=True, help="Model") parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") - parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype") + parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype") parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length") parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps") @@ -30,15 +30,19 @@ def main(): else: args.n_processes = int(args.n_processes) - processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)] + pipe_recv, pipe_send = mp.Pipe(duplex=False) + processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)] for proc in processes: proc.start() for proc in processes: proc.join() + speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)]) + logger.info(f"Final result: {speed=:.2f}") + @torch.inference_mode() -def benchmark_inference(process_idx, args): +def benchmark_inference(process_idx, args, result_pipe): tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway @@ -61,7 +65,7 @@ def benchmark_inference(process_idx, args): speed = 1 / np.mean(step_times) logger.info(f"{process_idx=} {step=} {speed=:.2f}") - logger.info(f"Final result: {process_idx=} {speed=:.2f}") + result_pipe.send(speed) if __name__ == "__main__": diff --git a/benchmarks/benchmark_training.py b/benchmarks/benchmark_training.py index 85061a3..f542907 100755 --- a/benchmarks/benchmark_training.py +++ b/benchmarks/benchmark_training.py @@ -20,7 +20,7 @@ def main(): parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client") parser.add_argument("--task", type=str, default="cls", help="Training task type") parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") - parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype") + parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype") parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") parser.add_argument("--seq_len", type=int, default=128, help="Sequence length") parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens") @@ -36,14 +36,18 @@ def main(): else: args.n_processes = int(args.n_processes) - processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)] + pipe_recv, pipe_send = mp.Pipe(duplex=False) + processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)] for proc in processes: proc.start() for proc in processes: proc.join() + fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0) + logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}") -def benchmark_training(process_idx, args): + +def benchmark_training(process_idx, args, result_pipe): if args.task == "cls": model = AutoDistributedModelForSequenceClassification.from_pretrained( args.model, @@ -96,7 +100,7 @@ def benchmark_training(process_idx, args): bwd_speed = input_ids.numel() / np.mean(bwd_times) logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}") - logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}") + result_pipe.send((fwd_speed, bwd_speed)) if __name__ == "__main__": From 55eb36ef4829d61315361049edf85753c631dab8 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 9 Aug 2023 21:59:56 +0300 Subject: [PATCH 166/168] Fix missing torch.cuda.synchronize for computing throughput (#456) --- src/petals/server/throughput.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index d977611..2806183 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -51,7 +51,7 @@ def get_server_throughput( if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR lock_path = Path(cache_dir, "throughput.lock") - cache_path = Path(cache_dir, "throughput_v4.json") + cache_path = Path(cache_dir, "throughput_v5.json") # We use the system-wide lock since only one process at a time can measure the host throughput os.makedirs(lock_path.parent, exist_ok=True) @@ -196,6 +196,7 @@ def measure_compute_rps( n_steps: int, inference: bool, ) -> float: + device = torch.device(device) if not tensor_parallel_devices: tensor_parallel_devices = (device,) with torch.inference_mode(): @@ -204,13 +205,17 @@ def measure_compute_rps( cache = None elapsed = 0 - for step in range(n_steps + 1): - dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype) + dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype) + _, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time + if device.type == "cuda": + torch.cuda.synchronize(device) - start_time = time.perf_counter() + start_time = time.perf_counter() + for step in range(n_steps): _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None) - if step >= 1: # Skip the 1st step to exclude the initialization time - elapsed += time.perf_counter() - start_time + if device.type == "cuda": + torch.cuda.synchronize(device) + elapsed = time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed devices_repr = get_device_name(device) From 056f22515ab89cb12a8e88445b46aad1c5d2777b Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 11 Aug 2023 09:24:33 +0400 Subject: [PATCH 167/168] Prioritize short inference, unmerge pools for long inference (#458) Right now, long inference requests may occupy Runtime for a few seconds without giving it away to process short (most latency-sensitive requests). This PR fixes it by disallowing the merged pool for long requests and prioritizing the short ones. --- src/petals/server/block_functions.py | 44 ++++++++++++++++++--------- src/petals/server/handler.py | 4 +++ src/petals/server/server.py | 1 + src/petals/server/task_prioritizer.py | 9 +++--- tests/test_block_exact_match.py | 41 ++++++++++++++----------- tests/test_remote_sequential.py | 2 +- 6 files changed, 64 insertions(+), 37 deletions(-) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 9208deb..c1f1d93 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -16,8 +16,15 @@ from petals.server.backend import TransformerBackend from petals.server.memory_cache import Handle from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_prioritizer import TaskPrioritizerBase +from petals.utils.convert_block import QuantType from petals.utils.misc import DUMMY, is_dummy +# We prioritize short inference requests and make them use a *merged* inference pool, +# so they are processed without interruptions and extra overheads +# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward +MAX_SHORT_INFERENCE_TOKENS = 128 +MAX_NF4_SHORT_INFERENCE_TOKENS = 1 + async def run_rpc_forward( *flat_tensors: torch.Tensor, @@ -127,9 +134,11 @@ async def iterate_rpc_inference( active_adapter: Optional[str], input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]], cache_handles: Sequence[Sequence[Handle]], + *, max_length: int, prioritizer: TaskPrioritizerBase, points: int, + quant_type: QuantType, ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]: assert len(cache_handles) == len(requested_backends) @@ -138,6 +147,7 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) + batch_size, length_increment, _ = hidden_states.shape # Cast inputs to backend dtype hidden_states = hidden_states.to(requested_backends[0].dtype) @@ -154,34 +164,40 @@ async def iterate_rpc_inference( if not (len(requested_backends) == len(prompts)): raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") - length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) if prefix_length + length_increment > max_length: raise ValueError( f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" f" exceeds pre-allocated maximum {max_length}" ) + merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS + can_merge_pools = batch_size * length_increment <= merge_max_tokens priority = prioritizer.prioritize( hidden_states, hypo_ids, points=point_per_piece, requested_uids=requested_uids, - type="inference", - ) - - inference_infos = tuple( - InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) - for uid, handles in zip(requested_uids, cache_handles) + type="short_inference" if can_merge_pools else "inference", ) - if hidden_states.numel() == 0: - pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. - # when user wants to pre-allocate cache or check that server *can* allocate that cache - else: + # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g. + # when user wants to pre-allocate cache or check that server *can* allocate that cache. + if hidden_states.numel() > 0: assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" - (hidden_states,) = await requested_backends[0].inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, *prompts, priority=priority - ) + if can_merge_pools: + inference_infos = tuple( + InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) + for uid, handles in zip(requested_uids, cache_handles) + ) + (hidden_states,) = await requested_backends[0].inference_pool.submit_task( + hidden_states, hypo_ids, inference_infos, *prompts, priority=priority + ) + else: + for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts): + inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),) + (hidden_states,) = await backend.inference_pool.submit_task( + hidden_states, hypo_ids, inference_infos, prompt, priority=priority + ) # serialize and send last layer outputs output_tensors = [ diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index b9be294..00df0d5 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -34,6 +34,7 @@ from petals.server.backend import TransformerBackend from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward from petals.server.memory_cache import Handle from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase +from petals.utils.convert_block import QuantType logger = get_logger(__name__) @@ -71,6 +72,7 @@ class TransformerConnectionHandler(ConnectionHandler): session_timeout: float, step_timeout: float, task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(), + quant_type: QuantType, ): super().__init__(dht, module_backends) for module_backend in self.module_backends.values(): @@ -88,6 +90,7 @@ class TransformerConnectionHandler(ConnectionHandler): self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout self._prioritizer = task_prioritizer + self.quant_type = quant_type async def add_p2p_handlers(self, *args, **kwargs) -> None: if self._listener_task is None: @@ -176,6 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler): max_length=max_length, prioritizer=self._prioritizer, points=points, + quant_type=self.quant_type, ): if can_push: task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index bf7470a..405dd9b 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -560,6 +560,7 @@ class ModuleContainer(threading.Thread): request_timeout=request_timeout, session_timeout=session_timeout, step_timeout=step_timeout, + quant_type=QuantType[server_info.quant_type.upper()], ) for i in range(num_handlers) ] diff --git a/src/petals/server/task_prioritizer.py b/src/petals/server/task_prioritizer.py index 6490fc5..4a575b1 100644 --- a/src/petals/server/task_prioritizer.py +++ b/src/petals/server/task_prioritizer.py @@ -13,9 +13,10 @@ class TaskPrioritizerBase(ABC): class DummyTaskPrioritizer(TaskPrioritizerBase): - """Simple implementation of TaskPrioritizer which gives constant zero priority for every task""" - def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float: + # Inference steps (especially short ones) go first since they are more latency-sensitive + if kwargs.get("type") == "short_inference": + return 1.0 if kwargs.get("type") == "inference": - return 1.0 # inference steps go first since they are more latency-sensitive - return 2.0 # forward, backward + return 2.0 + return 3.0 # Forward, backward diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index d98918b..80c695f 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -4,6 +4,7 @@ import pytest import torch from petals import AutoDistributedConfig, RemoteSequential +from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS from petals.server.from_pretrained import load_pretrained_block from test_utils import * @@ -13,26 +14,30 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) remote_sequential = RemoteSequential(config) - for block_index in random.sample(range(config.num_hidden_layers), 3): - remote_block = remote_sequential[block_index] + block_index = random.randint(0, config.num_hidden_layers - 1) + remote_block = remote_sequential[block_index] - inputs = torch.randn(1, 8, config.hidden_size) - outputs_forward = remote_block(inputs) + inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size) + outputs_forward = remote_block(inputs) - outputs_inference = [] - with torch.inference_mode(): - with remote_block.inference_session(max_length=inputs.shape[1]) as sess: - for i in range(inputs.shape[1]): - outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) + outputs_inference = [] + with torch.inference_mode(): + with remote_block.inference_session(max_length=inputs.shape[1]) as sess: + # Test long inference (unmerged inference pools) + outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :])) - # test that max length is respected - with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: - sess.step(inputs[:, -1:, :]) - assert "Maximum length exceeded" in repr(exc_info.value) - outputs_inference = torch.cat(outputs_inference, dim=1) + # Test short inference (merged inference pools) + for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]): + outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) - ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) - (outputs_local,) = ref_block(inputs) + # test that max length is respected + with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: + sess.step(inputs[:, -1:, :]) + assert "Maximum length exceeded" in repr(exc_info.value) + outputs_inference = torch.cat(outputs_inference, dim=1) - assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward) - assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference) + ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) + (outputs_local,) = ref_block(inputs) + + assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward) + assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 9189e68..30698c5 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -40,7 +40,7 @@ def test_remote_sequential(): assert hidden.shape == test_inputs.shape assert hidden.requires_grad second_half_outputs = second_half(hidden) - assert torch.allclose(second_half_outputs, full_outputs, atol=3e-4) + assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3) (second_half_outputs * grad_proj).sum().backward() assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2) From 722c4dc49651f4d3f22d3c790e502507e7527a12 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 11 Aug 2023 09:34:05 +0400 Subject: [PATCH 168/168] Bump version to 2.0.1.post2 (#459) --- src/petals/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index c9f8223..1ef8609 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -11,7 +11,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.0.1.post1" +__version__ = "2.0.1.post2" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

Network