diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 05cebdd..74b731d 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -48,7 +48,6 @@ jobs: export MODEL_NAME="${{ matrix.model }}" export REF_NAME="${{ matrix.model }}" export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}" - export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}" # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) @@ -61,27 +60,25 @@ jobs: until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init - python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \ - --mean_balance_check_period 10 \ - --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log & + export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \ + --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS" + export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}" + + $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log & SERVER1_PID=$! # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there sleep 10 # wait for the 1st server to choose blocks - python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \ - --identity_path tests/server2.id \ - --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log & + $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log & SERVER2_PID=$! - python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \ - --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \ - --initial_peers $INITIAL_PEERS --throughput auto &> server3.log & + $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \ + --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log & SERVER3_PID=$! # ^-- chunking test - python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \ - --initial_peers $INITIAL_PEERS --throughput auto &> server4.log & + $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log & SERVER4_PID=$! # ^-- tensor parallelism test (not compatible with adapters yet) @@ -102,6 +99,9 @@ jobs: export no_proxy=* export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES + # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely + export PETALS_MAX_RETRIES=10 + pytest tests --durations=0 --durations-min=1.0 -v # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers) @@ -118,4 +118,3 @@ jobs: # [Step 4] Clean up kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID - echo "Done!" diff --git a/README.md b/README.md index 6987489..63449ae 100644 --- a/README.md +++ b/README.md @@ -8,14 +8,14 @@

-Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: +Generate text with distributed **Llama 2** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: ```python from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM # Choose any model available at https://health.petals.dev -model_name = "petals-team/StableBeluga2" +model_name = "petals-team/StableBeluga2" # This one is fine-tuned Llama 2 (70B) # Connect to a distributed network hosting model layers tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -31,9 +31,9 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... 🚀  Try now in Colab

-🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev). +🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. -🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. +🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev). 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)! @@ -81,9 +81,8 @@ python3 -m petals.cli.run_server petals-team/StableBeluga2 ## How does it work? -- Petals runs large language models like [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning. -- Single-batch inference runs at **up to 6 steps/sec** for **Llama 2** (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. -- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch. +- You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Single‑batch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps. +- You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.

@@ -113,99 +112,15 @@ Advanced guides: - Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) -## Benchmarks - -The benchmarks below are for BLOOM-176B: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-| Network | Bandwidth, round-trip latency | Single-batch inference, seq. 128 (steps/s) | Single-batch inference, seq. 2048 (steps/s) | Parallel forward, batch 1 (tokens/s) | Parallel forward, batch 64 (tokens/s) |
-|---|---|---|---|---|---|
-| Offloading, max. possible speed on 1x A100 ¹ | 256 Gbit/s | 0.18 | 0.18 | 2.7 | 170.3 |
-| Offloading, max. possible speed on 1x A100 ¹ | 128 Gbit/s | 0.09 | 0.09 | 2.4 | 152.8 |
-| Petals on 14 heterogeneous servers across Europe and North America ² | Real world | 0.83 | 0.79 | 32.6 | 179.4 |
-| Petals on 3 servers, with one A100 each ³ | 1 Gbit/s, < 5 ms | 1.71 | 1.54 | 70.0 | 253.6 |
-| Petals on 3 servers, with one A100 each ³ | 100 Mbit/s, < 5 ms | 1.66 | 1.49 | 56.4 | 182.0 |
-| Petals on 3 servers, with one A100 each ³ | 100 Mbit/s, 100 ms | 1.23 | 1.11 | 19.7 | 112.2 |
- -1 **An upper bound for offloading performance.** We base our offloading numbers on the best possible hardware setup for offloading: CPU RAM offloading via PCIe 4.0 with 16 PCIe lanes per GPU and PCIe switches for pairs of GPUs. We assume zero latency for the upper bound estimation. In 8-bit, the model uses 1 GB of memory per billion parameters. PCIe 4.0 with 16 lanes has a throughput of 256 Gbit/s, so offloading 176B parameters takes 5.5 seconds. The throughput is twice as slow (128 Gbit/s) if we have two GPUs behind the same PCIe switch. - -2 **A real-world distributed setting** with 14 servers holding 2× RTX 3060, 4× 2080Ti, 2× 3090, 2× A4000, and 4× A5000 GPUs. These are personal servers and servers from university labs, spread across Europe and North America and connected to the Internet at speeds of 100–1000 Mbit/s. 4 servers operate from under firewalls. - -3 **An optimistic setup** that requires least communication. The client nodes have 8 CPU cores and no GPU. - -We provide more evaluations and discuss these results in more detail in **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf). - -## 🛠️ Contributing +### Benchmarks + +Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf). + +### 🛠️ Contributing Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing. -## 📜 Citation +### 📜 Citation Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel. [Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188) diff --git a/setup.cfg b/setup.cfg index c8dbc9a..ef35f84 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = accelerate>=0.22.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 - transformers>=4.32.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py + transformers>=4.32.0,<4.35.0 # if you change this, please also change version assert in petals/__init__.py speedtest-cli==2.1.3 pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet hivemind==1.1.10.post2 @@ -47,7 +47,7 @@ install_requires = cpufeature>=0.2.0; platform_machine == "x86_64" packaging>=20.9 sentencepiece>=0.1.99 - peft>=0.5.0 + peft==0.5.0 safetensors>=0.3.1 Dijkstar>=2.6.0 diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 27076ba..0e2be34 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -17,13 +17,13 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.1.0" +__version__ = "2.3.0.dev1" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): assert ( - version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0") - ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0" + version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0") + ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0" assert version.parse("1.1.10") <= version.parse( hivemind.__version__ ), "Please install a proper hivemind version: pip install hivemind>=1.1.10" diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 94f5c2e..5208438 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -70,17 +70,17 @@ def main(): 
parser.add_argument('--inference_max_length', type=int, default=None, help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. ' - 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others') parser.add_argument('--min_batch_size', type=int, default=1, help='Minimum required batch size for all operations (in total tokens)') parser.add_argument('--max_batch_size', type=int, default=None, help='The total number of tokens in the same batch will not exceed this value. ' - 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others') parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024, help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks') parser.add_argument('--attn_cache_tokens', type=int, default=None, help='The number of past attention key/value pairs that will be stored between inference steps. ' - 'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others') parser.add_argument('--cache_dir', type=str, default=None, help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.') diff --git a/src/petals/client/config.py b/src/petals/client/config.py index e255024..a2f8f42 100644 --- a/src/petals/client/config.py +++ b/src/petals/client/config.py @@ -1,10 +1,14 @@ import dataclasses +import os from typing import Optional, Sequence, Union from hivemind import PeerID from petals.constants import PUBLIC_INITIAL_PEERS +_max_retries = os.getenv("PETALS_MAX_RETRIES") +DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None + @dataclasses.dataclass class ClientConfig: @@ -21,7 +25,7 @@ class ClientConfig: request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests update_period: float = 60 # refresh DHT information once in this many seconds - max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) + max_retries: Optional[int] = DEFAULT_MAX_RETRIES # max number of retries before an exception (default: inf) min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) max_backoff: float = 60 # limit maximal sleep time between retries to this value ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds diff --git a/src/petals/client/from_pretrained.py b/src/petals/client/from_pretrained.py index f2c88d2..4b9d8e5 100644 --- a/src/petals/client/from_pretrained.py +++ b/src/petals/client/from_pretrained.py @@ -6,7 +6,6 @@ import tempfile from contextvars import ContextVar from typing import List, Optional, Tuple, Union -import torch from hivemind.utils.logging import get_logger from transformers import BloomPreTrainedModel, modeling_utils @@ -22,21 +21,14 @@ class FromPretrainedMixin: model_name_or_path: Union[str, os.PathLike, None], *args, low_cpu_mem_usage: Optional[bool] = None, - torch_dtype: Optional[Union[str, torch.dtype]] = None, **kwargs, ): model_name_or_path = get_compatible_model_repo(model_name_or_path) if 
low_cpu_mem_usage is None: low_cpu_mem_usage = True - if torch_dtype is None: - # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast, - # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights. - torch_dtype = "auto" with ignore_keys(cls._keys_to_ignore_on_load_unexpected): - return super().from_pretrained( - model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs - ) + return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( "low_cpu_mem_usage(`bool`, *optional*)", diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index f6195d8..8789be7 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -305,11 +305,21 @@ class InferenceSession: else: assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" assert prompts.shape[0] == self.num_blocks + assert prompts.shape[1] in (inputs.shape[0], 1) + assert prompts.shape[2] <= inputs.shape[1] + assert prompts.shape[3] == inputs.shape[2] + + if hypo_ids is None or is_dummy(hypo_ids): + hypo_ids = DUMMY_INT64 + else: + assert len(hypo_ids) == len(inputs) + assert hypo_ids.dtype == torch.int64 inputs_device = inputs.device inputs_dtype = inputs.dtype inputs = inputs.cpu() prompts = prompts.cpu() + hypo_ids = hypo_ids.cpu() step_id = str(uuid.uuid4()) n_input_tokens = inputs.shape[1] diff --git a/src/petals/client/lm_head.py b/src/petals/client/lm_head.py index cbea89d..bc0e293 100644 --- a/src/petals/client/lm_head.py +++ b/src/petals/client/lm_head.py @@ -1,8 +1,7 @@ import dataclasses import platform -from typing import Optional, Union +from typing import Union -import psutil import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -68,11 +67,10 @@ class LMHead(nn.Module): assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive" if not self._bf16_warning_shown: - if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: - logger.warning( - "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. " - "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)" - ) + logger.warning( + "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. 
" + "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)" + ) self._bf16_warning_shown = True hidden_states = hidden_states.float() diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py index bce6712..2c9137b 100644 --- a/src/petals/client/routing/sequence_info.py +++ b/src/petals/client/routing/sequence_info.py @@ -1,17 +1,15 @@ import dataclasses import time -from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar +from typing import Iterable, List, Optional, Tuple from hivemind import get_logger from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState +from petals.utils.dht import compute_spans logger = get_logger(__name__) -T = TypeVar("T") - - @dataclasses.dataclass class RemoteSequenceInfo: """ @@ -30,7 +28,7 @@ class RemoteSequenceInfo: last_updated_time: Optional[float] @classmethod - def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T: + def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo": block_uids = tuple(block_uids) empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids) empty_spans = tuple([] for _ in range(len(block_uids))) @@ -39,7 +37,7 @@ class RemoteSequenceInfo: def __getitem__(self, ix: slice): assert isinstance(ix, slice) block_uids, block_infos = self.block_uids[ix], self.block_infos[ix] - spans_by_priority, spans_containing_block = self.compute_spans(block_infos) + spans_by_priority, spans_containing_block = self._sort_spans(block_infos) return RemoteSequenceInfo( block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time ) @@ -47,60 +45,23 @@ class RemoteSequenceInfo: def __len__(self): return len(self.block_uids) - def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]): + def update_(self, new_block_infos: List[RemoteModuleInfo]): assert len(new_block_infos) == len(self.block_uids) for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)): - if info is None: - logger.debug(f"Found no block info for block {uid}") - continue - if not isinstance(info, RemoteModuleInfo): - logger.warning(f"Unexpected dht entry type for {uid}: {info}") - continue - if not info.servers: - logger.debug(f"Found no active peers for block {uid}") - continue - if info.uid != uid: - logger.warning(f"The DHT entry for {uid} actually points to {info.uid}") - continue + assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}" self.block_infos[block_index].servers = info.servers - self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos) + self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos) self.last_updated_time = time.perf_counter() @staticmethod - def compute_spans(block_infos: Sequence[RemoteModuleInfo]): - closed_spans = [] - active_spans = {} - for block_index, info in enumerate(block_infos): - if info is not None: - for peer_id, server_info in info.servers.items(): - if server_info.state != ServerState.ONLINE: - continue - if peer_id not in active_spans: - active_spans[peer_id] = RemoteSpanInfo( - peer_id=peer_id, - start=block_index, - end=block_index + 1, - server_info=server_info, - ) - else: # peer_id in active_spans - active_spans[peer_id].end = block_index + 1 - - for peer_id in list(active_spans.keys()): - if ( - info is None - or peer_id not in info.servers - or info.servers[peer_id].state != ServerState.ONLINE - or 
block_index == len(block_infos) - 1 - ): - closed_spans.append(active_spans.pop(peer_id)) - assert not active_spans, f"spans: {active_spans}" - - closed_spans.sort(key=lambda span: span.length, reverse=True) + def _sort_spans(block_infos: List[RemoteModuleInfo]): + spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values()) + spans_by_priority.sort(key=lambda span: span.length, reverse=True) - spans_containing_block = tuple(list() for _ in range(len(block_infos))) - for span in closed_spans: + spans_containing_block = tuple([] for _ in range(len(block_infos))) + for span in spans_by_priority: for block_index in range(span.start, span.end): spans_containing_block[block_index].append(span) - return closed_spans, spans_containing_block + return spans_by_priority, spans_containing_block diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 884a37b..c9cc94b 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -117,7 +117,6 @@ class RemoteSequenceManager: if state.sequence_info.last_updated_time is not None: assert block_uids == state.sequence_info.block_uids self._thread.ready.set() # no need to await the first dht fetch - self._need_latest_infos = True @staticmethod def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]: @@ -346,9 +345,6 @@ class RemoteSequenceManager: ) for block_info in new_block_infos: - if not block_info: - continue - # Apply allow and block lists block_info.servers = { peer_id: server_info diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 7c86f14..9cbbf76 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -11,18 +11,15 @@ UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. 
"bloom.layer3 bloom.layer4" -class ServerState(Enum): - OFFLINE = 0 - JOINING = 1 - ONLINE = 2 - - -RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) +def parse_uid(uid: ModuleUID) -> Tuple[str, int]: + assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs" + dht_prefix, index = uid.split(UID_DELIMITER) + return dht_prefix, int(index) @pydantic.dataclasses.dataclass class ModelInfo: - num_blocks: int + num_blocks: pydantic.conint(ge=1, strict=True) repository: Optional[str] = None def to_dict(self) -> dict: @@ -33,11 +30,23 @@ class ModelInfo: return cls(**source) +class ServerState(Enum): + OFFLINE = 0 + JOINING = 1 + ONLINE = 2 + + +RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) + + @pydantic.dataclasses.dataclass class ServerInfo: state: ServerState throughput: RPS + start_block: Optional[pydantic.conint(ge=0, strict=True)] = None + end_block: Optional[pydantic.conint(ge=0, strict=True)] = None + public_name: Optional[str] = None version: Optional[str] = None @@ -83,9 +92,17 @@ class RemoteSpanInfo: server_info: ServerInfo @property - def length(self): + def length(self) -> int: return self.end - self.start + @property + def state(self) -> ServerState: + return self.server_info.state + + @property + def throughput(self) -> float: + return self.server_info.throughput + RPCInfo = Dict[str, Any] diff --git a/src/petals/models/falcon/config.py b/src/petals/models/falcon/config.py index a1ae5e9..9fadede 100644 --- a/src/petals/models/falcon/config.py +++ b/src/petals/models/falcon/config.py @@ -31,6 +31,9 @@ class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): + if "180B" in model_name_or_path.upper(): + logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license") + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: dht_prefix = str(model_name_or_path) diff --git a/src/petals/models/falcon/model.py b/src/petals/models/falcon/model.py index 32c0b6f..296214d 100644 --- a/src/petals/models/falcon/model.py +++ b/src/petals/models/falcon/model.py @@ -47,6 +47,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[RemotePastKeyValues] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -68,6 +69,9 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix assert ( attention_mask is None or (attention_mask == 1).all() ), f"Custom attention masks are not supported, {attention_mask=}" + assert ( + position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all() + ), f"Non-consecutive position_ids are not supported, {position_ids=}" assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py index 55f659a..a8d433d 100644 --- a/src/petals/models/llama/block.py +++ b/src/petals/models/llama/block.py @@ -3,13 
+3,219 @@ LLaMA intermediate layer Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py See commit history for authorship. """ +import math from typing import Optional, Tuple import torch -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + repeat_kv, + rotate_half, +) +from petals.utils.cuda_graphs import make_inference_graphed_callable + + +def apply_rotary_pos_emb(q, k, cos, sin): + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class OptimizedLlamaAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._rotary_graph = None + + def _optimized_apply_rotary(self, query_states, key_states, cos, sin): + if self._rotary_graph is None: + self._rotary_graph = make_inference_graphed_callable( + apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin) + ) + return self._rotary_graph(query_states, key_states, cos, sin) -class WrappedLlamaBlock(LlamaDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert not output_attentions + assert position_ids is None + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[:, :, kv_seq_len - q_len :] + sin = sin[:, :, kv_seq_len - q_len :] + + if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin) + else: + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class OptimizedLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.self_attn = OptimizedLlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.pre_attn_graph = None + self.post_attn_graph = None + + def _optimized_input_layernorm(self, hidden_states): + if self.pre_attn_graph is None: + self.pre_attn_graph = make_inference_graphed_callable( + self.input_layernorm.forward, sample_args=(hidden_states,) + ) + return self.pre_attn_graph(hidden_states) + + def _optimized_output_layernorm(self, hidden_states): + if self.post_attn_graph is None: + self.post_attn_graph = make_inference_graphed_callable( + self.post_attention_layernorm.forward, sample_args=(hidden_states,) + ) + return self.post_attn_graph(hidden_states) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + hidden_states = self._optimized_input_layernorm(hidden_states) + else: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + hidden_states = self._optimized_output_layernorm(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class WrappedLlamaBlock(OptimizedLlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, @@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer): seq_length_with_past = seq_length_with_past + past_key_values_length past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) - if position_ids is None: - device = hidden_states.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + assert position_ids is None # embed positions if attention_mask is None: diff --git a/src/petals/server/block_selection.py b/src/petals/server/block_selection.py index cc050d4..441c0cd 100644 --- a/src/petals/server/block_selection.py +++ b/src/petals/server/block_selection.py @@ -1,54 +1,23 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List import numpy as np from hivemind import PeerID, get_logger -from petals.data_structures import RemoteModuleInfo, ServerState - -__all__ = ["choose_best_blocks", "should_choose_other_blocks"] +from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState +from petals.utils.dht import compute_spans logger = get_logger(__name__) -@dataclass -class Span: - start: int - end: int - throughput: float - state: ServerState - - @property - def length(self): - return self.end - self.start - - def move_to(self, new_start: int) -> None: - self.start, self.end = new_start, new_start + self.length - - -def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]: - spans = {} - throughputs = np.zeros(len(module_infos)) - for block, module in enumerate(module_infos): - if module is None: - continue - - # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers. - # If the order were not defined, we would get slightly different values due to floating point errors, - # which may cause excess block replacements. 
- for peer_id, server in sorted(module.servers.items()): - if server.state == ServerState.OFFLINE: - continue +def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray: + # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers. + # If the order were not defined, we would get slightly different values due to floating point errors, + # which may cause excess block replacements. - if peer_id in spans: - spans[peer_id].start = min(spans[peer_id].start, block) - spans[peer_id].end = max(spans[peer_id].start, block + 1) - else: - spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state) - - throughputs[block] += server.throughput - - return spans, throughputs + throughputs = np.zeros(total_blocks) + for span in sorted(spans.values(), key=lambda span: span.peer_id): + throughputs[span.start : span.end] += span.throughput + return throughputs def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int: @@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int: return min(options)[-1] -def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]: - _, throughputs = compute_spans(module_infos) +def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]: + spans = compute_spans(module_infos, min_state=ServerState.JOINING) + throughputs = compute_throughputs(spans, total_blocks=len(module_infos)) + start = _choose_best_start(throughputs, num_blocks) return list(range(start, start + num_blocks)) +def _move_span(span: RemoteSpanInfo, new_start: int): + span.start, span.end = new_start, new_start + span.length + + def should_choose_other_blocks( - local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float + local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float ) -> bool: if balance_quality > 1.0: return True # Forces rebalancing on each check (may be used for debugging purposes) - spans, throughputs = compute_spans(module_infos) + spans = compute_spans(module_infos, min_state=ServerState.JOINING) + throughputs = compute_throughputs(spans, total_blocks=len(module_infos)) initial_throughput = throughputs.min() eps = 1e-3 @@ -88,7 +64,7 @@ def should_choose_other_blocks( return False # This server is on its best place already throughputs[local_span.start : local_span.end] += local_span.throughput * eps - local_span.move_to(new_start) + _move_span(local_span, new_start) throughputs[local_span.start : local_span.end] += local_span.throughput moved = True @@ -105,7 +81,7 @@ def should_choose_other_blocks( throughputs[span.start : span.end] += span.throughput * eps if span.start != new_start: - span.move_to(new_start) + _move_span(span, new_start) moved = True throughputs[span.start : span.end] += span.throughput diff --git a/src/petals/server/server.py b/src/petals/server/server.py index e8e3d59..5769adb 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -24,7 +24,7 @@ from transformers import PretrainedConfig import petals from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid from petals.server import block_selection from petals.server.backend import 
TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size, resolve_block_dtype @@ -204,7 +204,7 @@ class Server: # For attention cache in GPU or RAM if attn_cache_tokens is None: - attn_cache_tokens = 32768 if is_multiquery_attn else 8192 + attn_cache_tokens = 16384 if is_multiquery_attn else 4096 cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens cache_values_per_block //= self.block_config.num_key_value_groups self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype) @@ -221,11 +221,10 @@ class Server: num_blocks = min(num_blocks, self.block_config.num_hidden_layers) if block_indices is not None: try: - first_block_index, last_block_index = block_indices.split(":") - first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index))) + start_block, end_block = [int(index.strip()) for index in block_indices.split(":")] except Exception as e: raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)") - block_indices = range(first_block_index, last_block_index) + block_indices = range(start_block, end_block) num_blocks = len(block_indices) self.strict_block_indices, self.num_blocks = block_indices, num_blocks @@ -704,11 +703,16 @@ class ModuleAnnouncerThread(threading.Thread): self.expiration = expiration self.trigger = threading.Event() + self.dht_prefix = parse_uid(module_uids[0])[0] + block_indices = [parse_uid(uid)[1] for uid in module_uids] + self.server_info.start_block = min(block_indices) + self.server_info.end_block = max(block_indices) + 1 + self.max_pinged = max_pinged - self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0] - block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids] - start_block, end_block = min(block_indices), max(block_indices) + 1 - self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)] + self.next_uids = [ + f"{self.dht_prefix}{UID_DELIMITER}{i}" + for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1) + ] self.ping_aggregator = PingAggregator(self.dht) def run(self) -> None: @@ -756,12 +760,11 @@ class ModuleAnnouncerThread(threading.Thread): def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]: module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True) - middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers} + middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers} pinged_servers = set(sample_up_to(middle_servers, self.max_pinged)) pinged_servers.discard(self.dht.peer_id) - if module_infos[-1] is not None: - # Sample servers hosting the block after the last one (most likely continuations) separately - pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) + # Sample servers hosting the block after the last one (most likely continuations) separately + pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) self.ping_aggregator.ping(list(pinged_servers)) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index bf71f44..c42bdb9 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -56,7 +56,7 @@ def get_server_throughput( # We use the system-wide lock since only one process at a time can measure the host throughput os.makedirs(lock_path.parent, exist_ok=True) - with 
open(lock_path, "wb") as lock_fd: + with open(lock_path, "wb+") as lock_fd: logger.info("Loading throughput info") fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX) # The OS will release the lock when lock_fd is closed or the process is killed diff --git a/src/petals/utils/cuda_graphs.py b/src/petals/utils/cuda_graphs.py new file mode 100644 index 0000000..216ecf1 --- /dev/null +++ b/src/petals/utils/cuda_graphs.py @@ -0,0 +1,76 @@ +import torch +from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten + + +def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3): + """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass""" + assert not isinstance(callable, torch.nn.Module) + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`." + ) + + flatten_arg, _ = _tree_flatten(sample_args) + flatten_sample_args = tuple(flatten_arg) + assert all( + isinstance(arg, torch.Tensor) for arg in flatten_arg + ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed." + + len_user_args = len(sample_args) + static_input_surface = flatten_sample_args + + graph = torch.cuda.CUDAGraph() + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(num_warmup_iters): + outputs, _ = _tree_flatten(callable(*sample_args)) + del outputs + torch.cuda.current_stream().wait_stream(s) + + # Capture forward graph + with torch.cuda.graph(graph): + outputs = callable(*sample_args) + + flatten_outputs, output_unflatten_spec = _tree_flatten(outputs) + static_outputs = tuple(flatten_outputs) + + def make_graphed_function( + graph, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + ): + def replay_graph(*inputs): + # At this stage, only the user args may (potentially) be new tensors. + for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + def functionalized(*user_args): + # Runs the autograd function with inputs == all inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. 
+ flatten_user_args, _ = _tree_flatten(user_args) + out = replay_graph(*flatten_user_args) + return _tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callable + graphed = make_graphed_function( + graph, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + ) + return graphed diff --git a/src/petals/utils/dht.py b/src/petals/utils/dht.py index 0710f60..4faf74a 100644 --- a/src/petals/utils/dht.py +++ b/src/petals/utils/dht.py @@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.p2p import PeerID from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo +from petals.data_structures import ( + CHAIN_DELIMITER, + UID_DELIMITER, + ModuleUID, + RemoteModuleInfo, + RemoteSpanInfo, + ServerInfo, + ServerState, + parse_uid, +) logger = get_logger(__name__) @@ -70,7 +79,7 @@ def get_remote_module_infos( *, latest: bool = False, return_future: bool = False, -) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]: +) -> Union[List[RemoteModuleInfo], MPFuture]: return dht.run_coroutine( partial( _get_remote_module_infos, @@ -90,7 +99,7 @@ async def _get_remote_module_infos( active_adapter: Optional[str], expiration_time: Optional[DHTExpiration], latest: bool, -) -> List[Optional[RemoteModuleInfo]]: +) -> List[RemoteModuleInfo]: if latest: assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both" expiration_time = math.inf @@ -99,14 +108,14 @@ async def _get_remote_module_infos( num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers) - modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids) - for i, uid in enumerate(uids): - metadata = found[uid] + modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids] + for module_info in modules: + metadata = found[module_info.uid] if metadata is None or not isinstance(metadata.value, dict): if metadata is not None: - logger.warning(f"Incorrect metadata for {uid}: {metadata}") + logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}") continue - servers = {} + for peer_id, server_info in metadata.value.items(): try: peer_id = PeerID.from_base58(peer_id) @@ -116,9 +125,29 @@ async def _get_remote_module_infos( logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}") continue - servers[peer_id] = server_info + module_info.servers[peer_id] = server_info except (TypeError, ValueError) as e: - logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}") - if servers: - modules[i] = RemoteModuleInfo(uid, servers) + logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}") return modules + + +def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]: + block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0 + num_blocks = len(module_infos) + + spans = {} + for block_idx, module_info in enumerate(module_infos): + for peer_id, server_info in sorted(module_info.servers.items()): + if server_info.state.value < min_state.value: + continue + + if peer_id not in spans or spans[peer_id].state.value < server_info.state.value: + spans[peer_id] = RemoteSpanInfo( + peer_id=peer_id, 
start=block_idx, end=block_idx + 1, server_info=server_info + ) + if server_info.start_block is not None and server_info.end_block is not None: + spans[peer_id].start = max(server_info.start_block - block_offset, 0) + spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks) + elif spans[peer_id].state == server_info.state: + spans[peer_id].end = max(spans[peer_id].end, block_idx + 1) + return spans diff --git a/src/petals/utils/disk_cache.py b/src/petals/utils/disk_cache.py index a26a0f5..5de47c8 100644 --- a/src/petals/utils/disk_cache.py +++ b/src/petals/utils/disk_cache.py @@ -22,7 +22,7 @@ def _blocks_lock(cache_dir: Optional[str], mode: int): lock_path = Path(cache_dir, BLOCKS_LOCK_FILE) os.makedirs(lock_path.parent, exist_ok=True) - with open(lock_path, "wb") as lock_fd: + with open(lock_path, "wb+") as lock_fd: fcntl.flock(lock_fd.fileno(), mode) # The OS will release the lock when lock_fd is closed or the process is killed yield diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 5baa1a2..84cbfff 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple import pytest import torch from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, convert_block @@ -94,10 +95,91 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): return state -@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models") +class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + batch_size, seq_length, _ = hidden_states.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + past_key_value = layer_past + if past_key_value is not None: + past_key_values_length = past_key_value[0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) + + if position_ids is None: + device = hidden_states.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = LlamaModel._prepare_decoder_attention_mask( + None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_llama_to_bloom( + present_key_value, batch_size, 
seq_length_with_past + ) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_llama( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + key_states = key_states.permute(0, 2, 1) + key_states = key_states.view( + batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) + value_states = value_states.view(*key_states.shape) + return (key_states, value_states) + + def _reorder_cache_from_llama_to_bloom( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + value_states = value_states.view( + batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) + key_states = key_states.view(*value_states.shape) + key_states = key_states.permute(0, 2, 1) + return (key_states, value_states) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) @pytest.mark.forked -def test_falcon(device): +def test_optimized_block(device): if device == "cuda:0" and not torch.cuda.is_available(): pytest.skip("CUDA tests can be run only in CUDA-enabled setups") @@ -108,11 +190,17 @@ def test_falcon(device): quant_type = QuantType.NONE block = config.block_class(config).to(dtype) - block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + + if config.model_type == "falcon": + unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) + elif config.model_type == "llama": + unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype) + else: + pytest.skip(f"This test is not applicable to {config.model_type} models") - unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) unopt_block = convert_block( - unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True + unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True ) unopt_block.load_state_dict(block.state_dict())
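For context on the `PETALS_MAX_RETRIES` variable exported in the CI workflow above, here is a minimal, self-contained sketch of the default-resolution logic added to `src/petals/client/config.py`. The `ClientConfig` below is trimmed to the single field this change touches; everything else is omitted for brevity.

```python
import dataclasses
import os
from typing import Optional

# Mirrors the logic added in src/petals/client/config.py: the variable is read
# once at import time, so it must be exported before `petals` is imported
# (the CI workflow above does this with `export PETALS_MAX_RETRIES=10`).
_max_retries = os.getenv("PETALS_MAX_RETRIES")
DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None


@dataclasses.dataclass
class ClientConfig:
    # None preserves the previous behavior: retry failed requests indefinitely.
    max_retries: Optional[int] = DEFAULT_MAX_RETRIES


if __name__ == "__main__":
    # Prints 10 if PETALS_MAX_RETRIES=10 was exported before this module ran, else None.
    print(ClientConfig().max_retries)
```

Reading the variable at import time keeps client construction untouched while letting the test environment cap retries, so failures surface as tracebacks instead of retrying forever.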