Merge remote-tracking branch 'origin/main' into forward_kwargs

# Conflicts:
#	src/petals/__init__.py
#	src/petals/client/inference_session.py
pull/467/head
commit 3195579620

@@ -48,7 +48,6 @@ jobs:
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
-export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

# [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
@@ -61,27 +60,25 @@ jobs:
until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init

-python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
-    --mean_balance_check_period 10 \
-    --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
+    --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
+export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
+
+$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
SERVER1_PID=$!
# ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there

sleep 10  # wait for the 1st server to choose blocks

-python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
-    --identity_path tests/server2.id \
-    --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+$RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
SERVER2_PID=$!

-python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
-    --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
-    --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
+$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
+    --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
SERVER3_PID=$!
# ^-- chunking test

-python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
-    --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
+$RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
SERVER4_PID=$!
# ^-- tensor parallelism test (not compatible with adapters yet)
@@ -102,6 +99,9 @@ jobs:
export no_proxy=*
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

+# Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely
+export PETALS_MAX_RETRIES=10
+
pytest tests --durations=0 --durations-min=1.0 -v

# [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
@@ -118,4 +118,3 @@ jobs:

# [Step 4] Clean up
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
-echo "Done!"

@@ -8,14 +8,14 @@
    <br>
</p>

-Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
+Generate text with distributed **Llama 2** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:

```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

# Choose any model available at https://health.petals.dev
-model_name = "petals-team/StableBeluga2"
+model_name = "petals-team/StableBeluga2"  # This one is fine-tuned Llama 2 (70B)

# Connect to a distributed network hosting model layers
tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -31,9 +31,9 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
    🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
</p>

-🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
+🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.

-🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
+🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).

💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
@@ -81,9 +81,8 @@ python3 -m petals.cli.run_server petals-team/StableBeluga2

## How does it work?

-- Petals runs large language models like [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
-- Single-batch inference runs at **up to 6 steps/sec** for **Llama 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
-- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
+- You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Single-batch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps.
+- You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.

<p align="center">
    <img src="https://i.imgur.com/RTYF3yW.png" width="800">
@@ -113,99 +112,15 @@ Advanced guides:

- Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
- Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)

-## Benchmarks
+### Benchmarks

-The benchmarks below are for BLOOM-176B:
+Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).

<table align="center">
<tr>
<th colspan="2">Network</th>
<th colspan="2">Single-batch inference<br>(steps/s)</th>
<th colspan="2">Parallel forward<br>(tokens/s)</th>
</tr>
<tr>
<th rowspan="2">Bandwidth</th>
<th rowspan="2">Round-trip<br>latency</th>
<th colspan="2">Sequence length</th>
<th colspan="2">Batch size</th>
</tr>
<tr align="center">
<td>128</td>
<td>2048</td>
<td>1</td>
<td>64</td>
</tr>
<tr>
<th colspan="6">Offloading, max. possible speed on 1x A100 <sup>1</sup></th>
</tr>
<tr align="center">
<td>256 Gbit/s</td>
<td></td>
<td>0.18</td>
<td>0.18</td>
<td>2.7</td>
<td>170.3</td>
</tr>
<tr align="center">
<td>128 Gbit/s</td>
<td></td>
<td>0.09</td>
<td>0.09</td>
<td>2.4</td>
<td>152.8</td>
</tr>
<tr>
<th colspan="6">Petals on 14 heterogeneous servers across Europe and North America <sup>2</sup></th>
</tr>
<tr align="center">
<td colspan="2">Real world</td>
<td>0.83</td>
<td>0.79</td>
<td>32.6</td>
<td>179.4</td>
</tr>
<tr>
<th colspan="6">Petals on 3 servers, with one A100 each <sup>3</sup></th>
</tr>
<tr align="center">
<td>1 Gbit/s</td>
<td>&lt; 5 ms</td>
<td>1.71</td>
<td>1.54</td>
<td>70.0</td>
<td>253.6</td>
</tr>
<tr align="center">
<td>100 Mbit/s</td>
<td>&lt; 5 ms</td>
<td>1.66</td>
<td>1.49</td>
<td>56.4</td>
<td>182.0</td>
</tr>
<tr align="center">
<td>100 Mbit/s</td>
<td>100 ms</td>
<td>1.23</td>
<td>1.11</td>
<td>19.7</td>
<td>112.2</td>
</tr>
</table>
<sup>1</sup> **An upper bound for offloading performance.** We base our offloading numbers on the best possible hardware setup for offloading: CPU RAM offloading via PCIe 4.0 with 16 PCIe lanes per GPU and PCIe switches for pairs of GPUs. We assume zero latency for the upper bound estimation. In 8-bit, the model uses 1 GB of memory per billion parameters. PCIe 4.0 with 16 lanes has a throughput of 256 Gbit/s, so offloading 176B parameters takes 5.5 seconds. The throughput is twice as slow (128 Gbit/s) if we have two GPUs behind the same PCIe switch.
<sup>2</sup> **A real-world distributed setting** with 14 servers holding 2× RTX 3060, 4× 2080Ti, 2× 3090, 2× A4000, and 4× A5000 GPUs. These are personal servers and servers from university labs, spread across Europe and North America and connected to the Internet at speeds of 100–1000 Mbit/s. 4 servers operate from under firewalls.
<sup>3</sup> **An optimistic setup** that requires least communication. The client nodes have 8 CPU cores and no GPU.
We provide more evaluations and discuss these results in more detail in **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
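For reference, the 5.5-second figure in footnote 1 is plain arithmetic over the numbers quoted there; a minimal sketch of that calculation (values taken from the footnote, not measured):

```python
# Back-of-the-envelope check of the offloading upper bound from footnote 1.
params = 176e9          # BLOOM-176B parameters
bytes_per_param = 1     # 8-bit weights: ~1 GB per billion parameters
pcie_gbit_per_s = 256   # PCIe 4.0 with 16 lanes

bytes_per_second = pcie_gbit_per_s * 1e9 / 8               # 32 GB/s
step_time = params * bytes_per_param / bytes_per_second    # seconds to stream all weights once
print(f"{step_time:.1f} s per step, {1 / step_time:.2f} steps/s")  # ~5.5 s per step, ~0.18 steps/s
```

This matches the 0.18 steps/s upper bound shown for offloading in the removed table above.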
-## 🛠️ Contributing
+### 🛠️ Contributing

Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.

-## 📜 Citation
+### 📜 Citation

Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
[Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)

@@ -37,7 +37,7 @@ install_requires =
    accelerate>=0.22.0
    huggingface-hub>=0.11.1,<1.0.0
    tokenizers>=0.13.3
-    transformers>=4.32.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers>=4.32.0,<4.35.0  # if you change this, please also change version assert in petals/__init__.py
    speedtest-cli==2.1.3
    pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
    hivemind==1.1.10.post2
@@ -47,7 +47,7 @@ install_requires =
    cpufeature>=0.2.0; platform_machine == "x86_64"
    packaging>=20.9
    sentencepiece>=0.1.99
-    peft>=0.5.0
+    peft==0.5.0
    safetensors>=0.3.1
    Dijkstar>=2.6.0

@@ -17,13 +17,13 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs

-__version__ = "2.1.0"
+__version__ = "2.3.0.dev1"

if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
+        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
    assert version.parse("1.1.10") <= version.parse(
        hivemind.__version__
    ), "Please install a proper hivemind version: pip install hivemind>=1.1.10"

@@ -70,17 +70,17 @@ def main():
    parser.add_argument('--inference_max_length', type=int, default=None,
                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
-                            'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                            'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
    parser.add_argument('--min_batch_size', type=int, default=1,
                        help='Minimum required batch size for all operations (in total tokens)')
    parser.add_argument('--max_batch_size', type=int, default=None,
                        help='The total number of tokens in the same batch will not exceed this value. '
-                            'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                            'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
    parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
                        help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
    parser.add_argument('--attn_cache_tokens', type=int, default=None,
                        help='The number of past attention key/value pairs that will be stored between inference steps. '
-                            'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
+                            'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')
    parser.add_argument('--cache_dir', type=str, default=None,
                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')

@@ -1,10 +1,14 @@
import dataclasses
+import os
from typing import Optional, Sequence, Union

from hivemind import PeerID

from petals.constants import PUBLIC_INITIAL_PEERS

+_max_retries = os.getenv("PETALS_MAX_RETRIES")
+DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None
+

@dataclasses.dataclass
class ClientConfig:
@@ -21,7 +25,7 @@ class ClientConfig:
    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
    update_period: float = 60  # refresh DHT information once in this many seconds
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds

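The CI change above and this `ClientConfig` change work together: `DEFAULT_MAX_RETRIES` is read from the environment once, at import time. A minimal sketch of how that override behaves (the module path follows `src/petals/client/config.py` from this diff; the values are illustrative):

```python
import os

# Must be set before petals is imported, since DEFAULT_MAX_RETRIES is evaluated at import time
os.environ["PETALS_MAX_RETRIES"] = "10"

from petals.client.config import ClientConfig

config = ClientConfig()
assert config.max_retries == 10  # with the variable unset, this stays None (retry indefinitely)
```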
@@ -6,7 +6,6 @@ import tempfile
from contextvars import ContextVar
from typing import List, Optional, Tuple, Union

-import torch
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils
@@ -22,21 +21,14 @@ class FromPretrainedMixin:
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
-        if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"

        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
-            return super().from_pretrained(
-                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
-            )
+            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",

@@ -305,11 +305,21 @@ class InferenceSession:
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64

        inputs_device = inputs.device
        inputs_dtype = inputs.dtype
        inputs = inputs.cpu()
        prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
        step_id = str(uuid.uuid4())

        n_input_tokens = inputs.shape[1]

@@ -1,8 +1,7 @@
import dataclasses
import platform
-from typing import Optional, Union
+from typing import Union

-import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -68,11 +67,10 @@ class LMHead(nn.Module):
        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

        if not self._bf16_warning_shown:
-            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
-                logger.warning(
-                    "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
-                    "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
-                )
+            logger.warning(
+                "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+                "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
+            )
            self._bf16_warning_shown = True

        hidden_states = hidden_states.float()

@@ -1,17 +1,15 @@
import dataclasses
import time
-from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
+from typing import Iterable, List, Optional, Tuple

from hivemind import get_logger

from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans

logger = get_logger(__name__)

-T = TypeVar("T")

@dataclasses.dataclass
class RemoteSequenceInfo:
    """
@@ -30,7 +28,7 @@ class RemoteSequenceInfo:
    last_updated_time: Optional[float]

    @classmethod
-    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
+    def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
        block_uids = tuple(block_uids)
        empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
        empty_spans = tuple([] for _ in range(len(block_uids)))
@@ -39,7 +37,7 @@ class RemoteSequenceInfo:
    def __getitem__(self, ix: slice):
        assert isinstance(ix, slice)
        block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
-        spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
+        spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
        return RemoteSequenceInfo(
            block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
        )
@@ -47,60 +45,23 @@ class RemoteSequenceInfo:
    def __len__(self):
        return len(self.block_uids)

-    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
+    def update_(self, new_block_infos: List[RemoteModuleInfo]):
        assert len(new_block_infos) == len(self.block_uids)
        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.debug(f"Found no block info for block {uid}")
-                continue
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-                continue
-            if not info.servers:
-                logger.debug(f"Found no active peers for block {uid}")
-                continue
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-                continue
+            assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
            self.block_infos[block_index].servers = info.servers

-        self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+        self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
        self.last_updated_time = time.perf_counter()

    @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
+    def _sort_spans(block_infos: List[RemoteModuleInfo]):
+        spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
+        spans_by_priority.sort(key=lambda span: span.length, reverse=True)
for block_index, info in enumerate(block_infos):
if info is not None:
for peer_id, server_info in info.servers.items():
if server_info.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(
peer_id=peer_id,
start=block_index,
end=block_index + 1,
server_info=server_info,
)
else: # peer_id in active_spans
active_spans[peer_id].end = block_index + 1
for peer_id in list(active_spans.keys()):
if (
info is None
or peer_id not in info.servers
or info.servers[peer_id].state != ServerState.ONLINE
or block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans, f"spans: {active_spans}"
closed_spans.sort(key=lambda span: span.length, reverse=True)
-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
+        spans_containing_block = tuple([] for _ in range(len(block_infos)))
-        for span in closed_spans:
+        for span in spans_by_priority:
            for block_index in range(span.start, span.end):
                spans_containing_block[block_index].append(span)

-        return closed_spans, spans_containing_block
+        return spans_by_priority, spans_containing_block

@@ -117,7 +117,6 @@ class RemoteSequenceManager:
        if state.sequence_info.last_updated_time is not None:
            assert block_uids == state.sequence_info.block_uids
        self._thread.ready.set()  # no need to await the first dht fetch
-        self._need_latest_infos = True

    @staticmethod
    def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
@@ -346,9 +345,6 @@ class RemoteSequenceManager:
        )

        for block_info in new_block_infos:
-            if not block_info:
-                continue
-
            # Apply allow and block lists
            block_info.servers = {
                peer_id: server_info

@@ -11,18 +11,15 @@ UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer
CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"

-class ServerState(Enum):
-    OFFLINE = 0
-    JOINING = 1
-    ONLINE = 2
+def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
+    assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
+    dht_prefix, index = uid.split(UID_DELIMITER)
+    return dht_prefix, int(index)

-RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)

@pydantic.dataclasses.dataclass
class ModelInfo:
-    num_blocks: int
+    num_blocks: pydantic.conint(ge=1, strict=True)
    repository: Optional[str] = None

    def to_dict(self) -> dict:
@@ -33,11 +30,23 @@ class ModelInfo:
        return cls(**source)

+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+

@pydantic.dataclasses.dataclass
class ServerInfo:
    state: ServerState
    throughput: RPS

+    start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+    end_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+
    public_name: Optional[str] = None
    version: Optional[str] = None
@@ -83,9 +92,17 @@ class RemoteSpanInfo:
    server_info: ServerInfo

    @property
-    def length(self):
+    def length(self) -> int:
        return self.end - self.start

+    @property
+    def state(self) -> ServerState:
+        return self.server_info.state
+
+    @property
+    def throughput(self) -> float:
+        return self.server_info.throughput
+

RPCInfo = Dict[str, Any]

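For context, `parse_uid()` introduced above splits a module UID of the form `<dht_prefix><UID_DELIMITER><block_index>`; a small illustrative sketch (the model prefix below is hypothetical):

```python
from petals.data_structures import UID_DELIMITER, parse_uid

uid = f"StableBeluga2-hf{UID_DELIMITER}17"  # hypothetical UID for block 17
dht_prefix, block_index = parse_uid(uid)
assert (dht_prefix, block_index) == ("StableBeluga2-hf", 17)
```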
@@ -31,6 +31,9 @@ class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig,
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
+        if "180B" in model_name_or_path.upper():
+            logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license")
+
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)

@@ -47,6 +47,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
@@ -68,6 +69,9 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"

@@ -3,13 +3,219 @@ LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
import math
from typing import Optional, Tuple

import torch
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
repeat_kv,
rotate_half,
)
from petals.utils.cuda_graphs import make_inference_graphed_callable
def apply_rotary_pos_emb(q, k, cos, sin):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class OptimizedLlamaAttention(LlamaAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._rotary_graph = None
def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
if self._rotary_graph is None:
self._rotary_graph = make_inference_graphed_callable(
apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
)
return self._rotary_graph(query_states, key_states, cos, sin)
-class WrappedLlamaBlock(LlamaDecoderLayer):
    def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert not output_attentions
assert position_ids is None
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos = cos[:, :, kv_seq_len - q_len :]
sin = sin[:, :, kv_seq_len - q_len :]
if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
else:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.self_attn = OptimizedLlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_attn_graph = None
self.post_attn_graph = None
def _optimized_input_layernorm(self, hidden_states):
if self.pre_attn_graph is None:
self.pre_attn_graph = make_inference_graphed_callable(
self.input_layernorm.forward, sample_args=(hidden_states,)
)
return self.pre_attn_graph(hidden_states)
def _optimized_output_layernorm(self, hidden_states):
if self.post_attn_graph is None:
self.post_attn_graph = make_inference_graphed_callable(
self.post_attention_layernorm.forward, sample_args=(hidden_states,)
)
return self.post_attn_graph(hidden_states)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
hidden_states = self._optimized_input_layernorm(hidden_states)
else:
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
hidden_states = self._optimized_output_layernorm(hidden_states)
else:
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
        seq_length_with_past = seq_length_with_past + past_key_values_length
        past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

-        if position_ids is None:
-            device = hidden_states.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+        assert position_ids is None

        # embed positions
        if attention_mask is None:

@@ -1,54 +1,23 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List

import numpy as np
from hivemind import PeerID, get_logger

-from petals.data_structures import RemoteModuleInfo, ServerState
+from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans

__all__ = ["choose_best_blocks", "should_choose_other_blocks"]

logger = get_logger(__name__)

-@dataclass
-class Span:
-    start: int
-    end: int
-    throughput: float
-    state: ServerState
-
-    @property
-    def length(self):
-        return self.end - self.start
-
-    def move_to(self, new_start: int) -> None:
-        self.start, self.end = new_start, new_start + self.length
-
-
-def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
-    spans = {}
-    throughputs = np.zeros(len(module_infos))
-    for block, module in enumerate(module_infos):
-        if module is None:
-            continue
-
-        # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
-        # If the order were not defined, we would get slightly different values due to floating point errors,
-        # which may cause excess block replacements.
-        for peer_id, server in sorted(module.servers.items()):
-            if server.state == ServerState.OFFLINE:
-                continue
-
-            if peer_id in spans:
-                spans[peer_id].start = min(spans[peer_id].start, block)
-                spans[peer_id].end = max(spans[peer_id].start, block + 1)
-            else:
-                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
-
-            throughputs[block] += server.throughput
-
-    return spans, throughputs
+def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
+    # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
+    # If the order were not defined, we would get slightly different values due to floating point errors,
+    # which may cause excess block replacements.
+    throughputs = np.zeros(total_blocks)
+    for span in sorted(spans.values(), key=lambda span: span.peer_id):
+        throughputs[span.start : span.end] += span.throughput
+    return throughputs


def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
@@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
    return min(options)[-1]

-def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
-    _, throughputs = compute_spans(module_infos)
+def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
    start = _choose_best_start(throughputs, num_blocks)
    return list(range(start, start + num_blocks))

+def _move_span(span: RemoteSpanInfo, new_start: int):
+    span.start, span.end = new_start, new_start + span.length
+

def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
+    local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
) -> bool:
    if balance_quality > 1.0:
        return True  # Forces rebalancing on each check (may be used for debugging purposes)

-    spans, throughputs = compute_spans(module_infos)
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
    initial_throughput = throughputs.min()
    eps = 1e-3
@@ -88,7 +64,7 @@ def should_choose_other_blocks(
        return False  # This server is on its best place already

    throughputs[local_span.start : local_span.end] += local_span.throughput * eps
-    local_span.move_to(new_start)
+    _move_span(local_span, new_start)
    throughputs[local_span.start : local_span.end] += local_span.throughput
    moved = True
@@ -105,7 +81,7 @@ def should_choose_other_blocks(
        throughputs[span.start : span.end] += span.throughput * eps

        if span.start != new_start:
-            span.move_to(new_start)
+            _move_span(span, new_start)
            moved = True
        throughputs[span.start : span.end] += span.throughput

@@ -24,7 +24,7 @@ from transformers import PretrainedConfig
import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -204,7 +204,7 @@ class Server:
        # For attention cache in GPU or RAM
        if attn_cache_tokens is None:
-            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
+            attn_cache_tokens = 16384 if is_multiquery_attn else 4096
        cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
        cache_values_per_block //= self.block_config.num_key_value_groups
        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
@@ -221,11 +221,10 @@ class Server:
        num_blocks = min(num_blocks, self.block_config.num_hidden_layers)

        if block_indices is not None:
            try:
-                first_block_index, last_block_index = block_indices.split(":")
-                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
+                start_block, end_block = [int(index.strip()) for index in block_indices.split(":")]
            except Exception as e:
                raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
-            block_indices = range(first_block_index, last_block_index)
+            block_indices = range(start_block, end_block)
            num_blocks = len(block_indices)
        self.strict_block_indices, self.num_blocks = block_indices, num_blocks
@@ -704,11 +703,16 @@ class ModuleAnnouncerThread(threading.Thread):
        self.expiration = expiration
        self.trigger = threading.Event()

+        self.dht_prefix = parse_uid(module_uids[0])[0]
+        block_indices = [parse_uid(uid)[1] for uid in module_uids]
+        self.server_info.start_block = min(block_indices)
+        self.server_info.end_block = max(block_indices) + 1
+
        self.max_pinged = max_pinged
-        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
-        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
-        start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [
+            f"{self.dht_prefix}{UID_DELIMITER}{i}"
+            for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1)
+        ]
        self.ping_aggregator = PingAggregator(self.dht)

    def run(self) -> None:
@@ -756,12 +760,11 @@ class ModuleAnnouncerThread(threading.Thread):
    def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
        module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
-        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
        pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
        pinged_servers.discard(self.dht.peer_id)
-        if module_infos[-1] is not None:
-            # Sample servers hosting the block after the last one (most likely continuations) separately
-            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        # Sample servers hosting the block after the last one (most likely continuations) separately
+        pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
        self.ping_aggregator.ping(list(pinged_servers))

@@ -56,7 +56,7 @@ def get_server_throughput(
    # We use the system-wide lock since only one process at a time can measure the host throughput
    os.makedirs(lock_path.parent, exist_ok=True)
-    with open(lock_path, "wb") as lock_fd:
+    with open(lock_path, "wb+") as lock_fd:
        logger.info("Loading throughput info")
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
        # The OS will release the lock when lock_fd is closed or the process is killed

@@ -0,0 +1,76 @@
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten
def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
"""Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass"""
assert not isinstance(callable, torch.nn.Module)
if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
raise RuntimeError(
"make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
)
flatten_arg, _ = _tree_flatten(sample_args)
flatten_sample_args = tuple(flatten_arg)
assert all(
isinstance(arg, torch.Tensor) for arg in flatten_arg
), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."
len_user_args = len(sample_args)
static_input_surface = flatten_sample_args
graph = torch.cuda.CUDAGraph()
# Warmup
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(num_warmup_iters):
outputs, _ = _tree_flatten(callable(*sample_args))
del outputs
torch.cuda.current_stream().wait_stream(s)
# Capture forward graph
with torch.cuda.graph(graph):
outputs = callable(*sample_args)
flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
static_outputs = tuple(flatten_outputs)
def make_graphed_function(
graph,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
):
def replay_graph(*inputs):
# At this stage, only the user args may (potentially) be new tensors.
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
static_input_surface[i].copy_(inputs[i])
graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
def functionalized(*user_args):
# Runs the autograd function with inputs == all inputs to the graph that might require grad
# (explicit user args + module parameters)
# Assumes module params didn't change since capture.
flatten_user_args, _ = _tree_flatten(user_args)
out = replay_graph(*flatten_user_args)
return _tree_unflatten(out, output_unflatten_spec)
return functionalized
# Put together the final graphed callable
graphed = make_graphed_function(
graph,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
)
return graphed

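The new `petals/utils/cuda_graphs.py` helper above is used by the optimized Llama layers earlier in this diff. A minimal usage sketch, assuming a CUDA device is available (the function and tensors below are illustrative, not part of the diff): the callable is captured once with the sample arguments, and later calls with same-shaped inputs replay the recorded graph.

```python
import torch
from petals.utils.cuda_graphs import make_inference_graphed_callable

def scale_and_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x * 2.0 + y

sample_x = torch.randn(1, 8, device="cuda")
sample_y = torch.randn(1, 8, device="cuda")

# Warms up on a side stream, then captures a CUDA graph of the forward pass
graphed = make_inference_graphed_callable(scale_and_add, sample_args=(sample_x, sample_y))

with torch.inference_mode():
    # Inputs must match the sample shapes/dtypes; they are copied into the captured graph's buffers
    out = graphed(torch.ones(1, 8, device="cuda"), torch.zeros(1, 8, device="cuda"))
    print(out)
```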
@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo from petals.data_structures import (
CHAIN_DELIMITER,
UID_DELIMITER,
ModuleUID,
RemoteModuleInfo,
RemoteSpanInfo,
ServerInfo,
ServerState,
parse_uid,
)
logger = get_logger(__name__) logger = get_logger(__name__)
@ -70,7 +79,7 @@ def get_remote_module_infos(
*, *,
latest: bool = False, latest: bool = False,
return_future: bool = False, return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]: ) -> Union[List[RemoteModuleInfo], MPFuture]:
return dht.run_coroutine( return dht.run_coroutine(
partial( partial(
_get_remote_module_infos, _get_remote_module_infos,
@@ -90,7 +99,7 @@ async def _get_remote_module_infos(
     active_adapter: Optional[str],
     expiration_time: Optional[DHTExpiration],
     latest: bool,
-) -> List[Optional[RemoteModuleInfo]]:
+) -> List[RemoteModuleInfo]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
         expiration_time = math.inf
@@ -99,14 +108,14 @@ async def _get_remote_module_infos(
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
-    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        metadata = found[uid]
+    modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
+    for module_info in modules:
+        metadata = found[module_info.uid]
         if metadata is None or not isinstance(metadata.value, dict):
             if metadata is not None:
-                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
+                logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
             continue
-        servers = {}
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
@@ -116,9 +125,29 @@ async def _get_remote_module_infos(
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     continue
-                servers[peer_id] = server_info
+                module_info.servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
-                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
-        if servers:
-            modules[i] = RemoteModuleInfo(uid, servers)
+                logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
     return modules
+
+
+def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
+    block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
+    num_blocks = len(module_infos)
+
+    spans = {}
+    for block_idx, module_info in enumerate(module_infos):
+        for peer_id, server_info in sorted(module_info.servers.items()):
+            if server_info.state.value < min_state.value:
+                continue
+
+            if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
+                spans[peer_id] = RemoteSpanInfo(
+                    peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
+                )
+                if server_info.start_block is not None and server_info.end_block is not None:
+                    spans[peer_id].start = max(server_info.start_block - block_offset, 0)
+                    spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
+            elif spans[peer_id].state == server_info.state:
+                spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
+
+    return spans
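For orientation, the new `compute_spans` helper collapses per-block server records into at most one contiguous `RemoteSpanInfo` per peer, keeping only servers whose state is at or above `min_state`. A hedged usage sketch is below; the DHT handle, the module UID prefix, and the `ServerState.ONLINE` member are illustrative assumptions not shown in this diff.

# Hypothetical usage sketch: which block range does each peer serve right now?
# `dht`, the UID prefix, and ServerState.ONLINE are assumptions for illustration.
uids = [f"bigscience/bloom-560m-petals{UID_DELIMITER}{i}" for i in range(24)]
module_infos = get_remote_module_infos(dht, uids, latest=True)

spans = compute_spans(module_infos, min_state=ServerState.ONLINE)
for peer_id, span in spans.items():
    print(f"{peer_id} serves blocks [{span.start}:{span.end})")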

@@ -22,7 +22,7 @@ def _blocks_lock(cache_dir: Optional[str], mode: int):
     lock_path = Path(cache_dir, BLOCKS_LOCK_FILE)
     os.makedirs(lock_path.parent, exist_ok=True)
-    with open(lock_path, "wb") as lock_fd:
+    with open(lock_path, "wb+") as lock_fd:
         fcntl.flock(lock_fd.fileno(), mode)
         # The OS will release the lock when lock_fd is closed or the process is killed
         yield
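The `wb` → `wb+` change only affects how the lock file is opened; locking itself still relies on POSIX advisory locks via `fcntl.flock`, which the OS releases when the descriptor is closed or the process dies. Below is a self-contained sketch of the same pattern, with illustrative paths and lock modes that are not taken from the repo.

# Self-contained sketch of the advisory-lock pattern used above (paths are illustrative)
import fcntl
import os
from contextlib import contextmanager
from pathlib import Path

@contextmanager
def file_lock(lock_path: Path, mode: int):
    os.makedirs(lock_path.parent, exist_ok=True)
    with open(lock_path, "wb+") as lock_fd:
        fcntl.flock(lock_fd.fileno(), mode)  # blocks until the lock is granted
        yield
        # The OS releases the lock when lock_fd is closed or the process is killed

# Exclusive lock while mutating the cache; readers could take fcntl.LOCK_SH instead
with file_lock(Path("/tmp/petals_cache/blocks.lock"), fcntl.LOCK_EX):
    pass  # ... download or delete cached blocks here ...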

@@ -3,6 +3,7 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
@@ -94,10 +95,91 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
         return state
 
 
-@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
+class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        batch_size, seq_length, _ = hidden_states.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        past_key_value = layer_past
+        if past_key_value is not None:
+            past_key_values_length = past_key_value[0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+
+        if position_ids is None:
+            device = hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+            )
+        attention_mask = LlamaModel._prepare_decoder_attention_mask(
+            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        )
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_llama_to_bloom(
+                present_key_value, batch_size, seq_length_with_past
+            )
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_llama(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        key_states = key_states.permute(0, 2, 1)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        value_states = value_states.view(*key_states.shape)
+        return (key_states, value_states)
+
+    def _reorder_cache_from_llama_to_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        key_states = key_states.view(*value_states.shape)
+        key_states = key_states.permute(0, 2, 1)
+        return (key_states, value_states)
+
+
 @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
 @pytest.mark.forked
-def test_falcon(device):
+def test_optimized_block(device):
     if device == "cuda:0" and not torch.cuda.is_available():
         pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
@@ -108,11 +190,17 @@ def test_falcon(device):
     quant_type = QuantType.NONE
 
     block = config.block_class(config).to(dtype)
-    block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-    unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
+    if config.model_type == "falcon":
+        unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
+    elif config.model_type == "llama":
+        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+    else:
+        pytest.skip(f"This test is not applicable to {config.model_type} models")
+
     unopt_block = convert_block(
-        unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
     )
 
     unopt_block.load_state_dict(block.state_dict())
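The `_reorder_cache_from_bloom_to_llama` / `_reorder_cache_from_llama_to_bloom` helpers above convert the attention cache between a flat BLOOM-style layout (keys as `(batch * num_heads, head_dim, seq)`, values as `(batch * num_heads, seq, head_dim)`) and the LLaMA-style layout `(batch, num_heads, seq, head_dim)`. Below is a toy round-trip check under those shapes; the shapes are inferred from the code rather than stated in the diff.

# Toy shape round-trip for the cache reordering above (shapes inferred, not stated in the diff)
import torch

batch, num_heads, head_dim, seq = 2, 4, 8, 5

# BLOOM-style cache: keys (batch * num_heads, head_dim, seq), values (batch * num_heads, seq, head_dim)
bloom_keys = torch.randn(batch * num_heads, head_dim, seq)
bloom_values = torch.randn(batch * num_heads, seq, head_dim)

# bloom -> llama: both become (batch, num_heads, seq, head_dim)
llama_keys = bloom_keys.permute(0, 2, 1).reshape(batch, num_heads, seq, head_dim)
llama_values = bloom_values.view(batch, num_heads, seq, head_dim)

# llama -> bloom: invert the reshapes
restored_keys = llama_keys.reshape(batch * num_heads, seq, head_dim).permute(0, 2, 1)
restored_values = llama_values.view(batch * num_heads, seq, head_dim)

assert torch.equal(restored_keys, bloom_keys) and torch.equal(restored_values, bloom_values)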
