Compare commits

..

No commits in common. 'main' and 'v2.0.0.post1' have entirely different histories.

@ -42,17 +42,6 @@ jobs:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Free disk space on Ubuntu runner
uses: kfir4444/free-disk-space@main
with:
# found in: https://github.com/docker/build-push-action/issues/968
tool-cache: false
android: true
dotnet: true
haskell: true
large-packages: true
swap-storage: true
- name: Build and push
id: docker_build
uses: docker/build-push-action@v2

@ -7,26 +7,13 @@ on:
jobs:
run-tests:
runs-on: ubuntu-latest
strategy:
matrix:
include:
- { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }
- { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
- { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' }
- { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' }
python-version: [ '3.8', '3.9', '3.10' ]
fail-fast: false
runs-on: ${{ matrix.os }}-latest
timeout-minutes: 20
timeout-minutes: 15
steps:
- name: Increase swap space
if: ${{ matrix.os == 'ubuntu' }}
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
@ -44,77 +31,43 @@ jobs:
pip install .[dev]
- name: Test
run: |
set -x # Print executed commands
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
export MODEL_NAME=bigscience/bloom-560m
export REF_NAME=bigscience/bloom-560m
export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft
# [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
--torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log &
SERVER1_PID=$!
python -m petals.cli.run_dht \
--identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
BOOTSTRAP_PID=$!
sleep 5 # wait for the first server to initialize DHT
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
# ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs
until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init
export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
--device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
SERVER1_PID=$!
# ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
sleep 10 # wait for the 1st server to choose blocks
$RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
SERVER2_PID=$!
$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
--attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
SERVER3_PID=$!
# ^-- chunking test
$RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
SERVER4_PID=$!
# ^-- tensor parallelism test (not compatible with adapters yet)
sleep 5 # wait for the log files to appear
tail -n 100 -f bootstrap.log server*.log &
tail -n 100 -f server*.log &
LOGGER_PID=$!
sleep 30 # wait for servers to download layers
sleep 30 # wait for servers to eval throughput, download layers, and rebalance
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init
# [Step 2] Run PyTest
# Share disk cache between Petals servers, clients, and HF Transformers
export TRANSFORMERS_CACHE=~/.cache/petals
# Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
export no_proxy=*
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
# Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely
export PETALS_MAX_RETRIES=10
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
pytest tests --durations=0 --durations-min=1.0 -v
# [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3
python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --n_steps 1
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
# [Step 4] Clean up
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
echo "Done!"

@ -1,27 +1,24 @@
<p align="center">
<img src="https://i.imgur.com/7eR7Pan.png" width="400"><br>
Run large language models at home, BitTorrent-style.<br>
Fine-tuning and inference <a href="https://github.com/bigscience-workshop/petals#benchmarks">up to 10x faster</a> than offloading
<br><br>
<a href="https://pypi.org/project/petals/"><img src="https://img.shields.io/pypi/v/petals.svg?color=green"></a>
<a href="https://discord.gg/tfHfe8B34k"><img src="https://img.shields.io/discord/865254854262652969?label=discord&logo=discord&logoColor=white"></a>
<br>
Fine-tuning and inference <a href="https://github.com/bigscience-workshop/petals#benchmarks">up to 10x faster</a> than offloading<br><br>
<a href="https://pypi.org/project/petals/"><img src="https://img.shields.io/pypi/v/petals.svg?color=green"></a><br>
</p>
Generate text with distributed **Llama 2** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
Generate text with distributed [LLaMA 2](https://ai.meta.com/llama/) ([70B](https://huggingface.co/meta-llama/Llama-2-70b-hf), [70B-Chat](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM
# Choose any model available at https://health.petals.dev
model_name = "petals-team/StableBeluga2" # This one is fine-tuned Llama 2 (70B)
model_name = "enoch/llama-65b-hf"
# You can also use "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-70b-chat-hf",
# "bigscience/bloom", or "bigscience/bloomz"
# Connect to a distributed network hosting model layers
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
# Embeddings & prompts are on your device, transformer blocks are distributed across the Internet
# Run the model as if it were on your computer
inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
@ -31,96 +28,175 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
</p>
🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
📋 Make sure you follow the model's terms of use (see [LLaMA 2](https://bit.ly/llama2-license), [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license) licenses).
🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
🔏 Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
### Connect your GPU and increase Petals capacity
## Connect your GPU and increase Petals capacity
Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out [available models](https://health.petals.dev) and help serving one of them! As an example, here is how to host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your GPU:
🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+):
```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
pip install git+https://github.com/bigscience-workshop/petals
python -m petals.cli.run_server petals-team/StableBeluga2
pip install --upgrade petals
python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b
```
🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
Or run our [Docker](https://www.docker.com) image (works on Linux, macOS, and Windows with [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)):
```bash
sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
learningathome/petals:main \
python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
python -m petals.cli.run_server --port 31330 enoch/llama-65b-hf --adapters timdettmers/guanaco-65b
```
🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
```bash
brew install python
python3 -m pip install git+https://github.com/bigscience-workshop/petals
python3 -m petals.cli.run_server petals-team/StableBeluga2
```
🔒 Hosting a server does not allow others to run custom code on your computer. Learn more about security [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
<p align="center">
📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
</p>
💬 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
### Check out tutorials, examples, and more
🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
Basic tutorials:
🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
- Prompt-tune LLaMA-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
Useful tools and advanced guides:
- [Chatbot web app](http://chat.petals.ml) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/borzunov/chat.petals.ml)
- [Monitor](http://health.petals.ml) for the public swarm: [source code](https://github.com/borzunov/health.petals.ml)
- Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
- Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
Learning more:
- Frequently asked questions: [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions)
- In-depth system description: [paper](https://arxiv.org/abs/2209.01188)
## How does it work?
- You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Singlebatch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps.
- You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.
- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning.
- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](http://chat.petals.ml) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
<p align="center">
<img src="https://i.imgur.com/RTYF3yW.png" width="800">
</p>
<p align="center">
📜 &nbsp;<b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
📜 &nbsp;<b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
</p>
## 📚 Tutorials, examples, and more
Basic tutorials:
- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
- Prompt-tune Llama-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
Useful tools:
- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
Advanced guides:
- Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
- Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
## Installation
### Benchmarks
Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/distribution) on Linux:
Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
pip install --upgrade petals
```
### 🛠️ Contributing
If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
See the instructions for macOS and Windows, the full requirements, and troubleshooting advice in our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-client).
## Benchmarks
The benchmarks below are for BLOOM-176B:
<table align="center">
<tr>
<th colspan="2">Network</th>
<th colspan="2">Single-batch inference<br>(steps/s)</th>
<th colspan="2">Parallel forward<br>(tokens/s)</th>
</tr>
<tr>
<th rowspan="2">Bandwidth</th>
<th rowspan="2">Round-trip<br>latency</th>
<th colspan="2">Sequence length</th>
<th colspan="2">Batch size</th>
</tr>
<tr align="center">
<td>128</td>
<td>2048</td>
<td>1</td>
<td>64</td>
</tr>
<tr>
<th colspan="6">Offloading, max. possible speed on 1x A100 <sup>1</sup></th>
</tr>
<tr align="center">
<td>256 Gbit/s</td>
<td></td>
<td>0.18</td>
<td>0.18</td>
<td>2.7</td>
<td>170.3</td>
</tr>
<tr align="center">
<td>128 Gbit/s</td>
<td></td>
<td>0.09</td>
<td>0.09</td>
<td>2.4</td>
<td>152.8</td>
</tr>
<tr>
<th colspan="6">Petals on 14 heterogeneous servers across Europe and North America <sup>2</sup></th>
</tr>
<tr align="center">
<td colspan="2">Real world</td>
<td>0.83</td>
<td>0.79</td>
<td>32.6</td>
<td>179.4</td>
</tr>
<tr>
<th colspan="6">Petals on 3 servers, with one A100 each <sup>3</sup></th>
</tr>
<tr align="center">
<td>1 Gbit/s</td>
<td>&lt; 5 ms</td>
<td>1.71</td>
<td>1.54</td>
<td>70.0</td>
<td>253.6</td>
</tr>
<tr align="center">
<td>100 Mbit/s</td>
<td>&lt; 5 ms</td>
<td>1.66</td>
<td>1.49</td>
<td>56.4</td>
<td>182.0</td>
</tr>
<tr align="center">
<td>100 Mbit/s</td>
<td>100 ms</td>
<td>1.23</td>
<td>1.11</td>
<td>19.7</td>
<td>112.2</td>
</tr>
</table>
<sup>1</sup> **An upper bound for offloading performance.** We base our offloading numbers on the best possible hardware setup for offloading: CPU RAM offloading via PCIe 4.0 with 16 PCIe lanes per GPU and PCIe switches for pairs of GPUs. We assume zero latency for the upper bound estimation. In 8-bit, the model uses 1 GB of memory per billion parameters. PCIe 4.0 with 16 lanes has a throughput of 256 Gbit/s, so offloading 176B parameters takes 5.5 seconds. The throughput is twice as slow (128 Gbit/s) if we have two GPUs behind the same PCIe switch.
<sup>2</sup> **A real-world distributed setting** with 14 servers holding 2× RTX 3060, 4× 2080Ti, 2× 3090, 2× A4000, and 4× A5000 GPUs. These are personal servers and servers from university labs, spread across Europe and North America and connected to the Internet at speeds of 1001000 Mbit/s. 4 servers operate from under firewalls.
<sup>3</sup> **An optimistic setup** that requires least communication. The client nodes have 8 CPU cores and no GPU.
We provide more evaluations and discuss these results in more detail in **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
## 🛠️ Contributing
Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.
### 📜 Citation
## 📜 Citation
Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
[Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)
@ -142,5 +218,5 @@ _arXiv preprint arXiv:2209.01188,_ 2022.
This project is a part of the <a href="https://bigscience.huggingface.co/">BigScience</a> research workshop.
</p>
<p align="center">
<img src="https://petals.dev/bigscience.png" width="150">
<img src="https://petals.ml/bigscience.png" width="150">
</p>

@ -15,15 +15,15 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=128)
parser.add_argument("--n_steps", type=int, default=100)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--warmup_steps", type=int, default=1)
args = parser.parse_args()
if args.n_processes == "n_gpus":
@ -31,19 +31,15 @@ def main():
else:
args.n_processes = int(args.n_processes)
pipe_recv, pipe_send = mp.Pipe(duplex=False)
processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)]
processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)]
for proc in processes:
proc.start()
for proc in processes:
proc.join()
speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
logger.info(f"Final result: {speed=:.2f}")
@torch.inference_mode()
def benchmark_forward(process_idx, args, result_pipe):
def benchmark_forward(process_idx, args):
model = AutoDistributedModel.from_pretrained(
args.model,
initial_peers=args.initial_peers,
@ -68,7 +64,7 @@ def benchmark_forward(process_idx, args, result_pipe):
speed = input_ids.numel() / np.mean(step_times)
logger.info(f"{process_idx=} {step=} {speed=:.2f}")
result_pipe.send(speed)
logger.info(f"Final result: {process_idx=} {speed=:.2f}")
if __name__ == "__main__":

@ -16,13 +16,13 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=2048)
parser.add_argument("--warmup_steps", type=int, default=1)
args = parser.parse_args()
if args.n_processes == "n_gpus":
@ -30,19 +30,15 @@ def main():
else:
args.n_processes = int(args.n_processes)
pipe_recv, pipe_send = mp.Pipe(duplex=False)
processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]
processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)]
for proc in processes:
proc.start()
for proc in processes:
proc.join()
speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
logger.info(f"Final result: {speed=:.2f}")
@torch.inference_mode()
def benchmark_inference(process_idx, args, result_pipe):
def benchmark_inference(process_idx, args):
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
# Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway
@ -65,7 +61,7 @@ def benchmark_inference(process_idx, args, result_pipe):
speed = 1 / np.mean(step_times)
logger.info(f"{process_idx=} {step=} {speed=:.2f}")
result_pipe.send(speed)
logger.info(f"Final result: {process_idx=} {speed=:.2f}")
if __name__ == "__main__":

@ -15,18 +15,18 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
parser.add_argument("--task", type=str, default="cls", help="Training task type")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--task", type=str, default="cls")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=128)
parser.add_argument("--pre_seq_len", type=int, default=16)
parser.add_argument("--n_steps", type=int, default=10)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--warmup_steps", type=int, default=1)
args = parser.parse_args()
assert args.task in ["cls", "causal_lm"]
@ -36,18 +36,14 @@ def main():
else:
args.n_processes = int(args.n_processes)
pipe_recv, pipe_send = mp.Pipe(duplex=False)
processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)]
processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)]
for proc in processes:
proc.start()
for proc in processes:
proc.join()
fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0)
logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}")
def benchmark_training(process_idx, args, result_pipe):
def benchmark_training(process_idx, args):
if args.task == "cls":
model = AutoDistributedModelForSequenceClassification.from_pretrained(
args.model,
@ -100,7 +96,7 @@ def benchmark_training(process_idx, args, result_pipe):
bwd_speed = input_ids.numel() / np.mean(bwd_times)
logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
result_pipe.send((fwd_speed, bwd_speed))
logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
if __name__ == "__main__":

@ -92,6 +92,9 @@
},
"outputs": [],
"source": [
"# Choose a model you'd like to prompt-tune. We recommend starting with\n",
"# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n",
"# The code below uses LLaMA-65B.\n",
"MODEL_NAME = \"enoch/llama-65b-hf\"\n",
"\n",
"# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
@ -327,7 +330,7 @@
"id": "51770911"
},
"source": [
"Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](https://health.petals.dev/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!"
"Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](http://health.petals.ml/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!"
]
},
{

@ -18,7 +18,6 @@ classifiers =
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Topic :: Scientific/Engineering
Topic :: Scientific/Engineering :: Mathematics
Topic :: Scientific/Engineering :: Artificial Intelligence
@ -30,24 +29,24 @@ classifiers =
package_dir =
= src
packages = find:
python_requires = >=3.8
python_requires = >=3.8,<3.11
install_requires =
torch>=1.12,<2.3.0
bitsandbytes==0.41.1
accelerate>=0.27.2
torch>=1.12
bitsandbytes==0.40.1.post1
accelerate>=0.20.3,<0.21.0
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers==4.38.2 # if you change this, please also change version assert in petals/__init__.py
transformers>=4.31.0,<5.0.0
speedtest-cli==2.1.3
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
hivemind==1.1.10.post2
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind==1.1.8
hivemind==1.1.8
tensor_parallel==1.0.23
humanfriendly
async-timeout>=4.0.2
cpufeature>=0.2.0; platform_machine == "x86_64"
cpufeature>=0.2.0
packaging>=20.9
sentencepiece>=0.1.99
peft==0.5.0
peft>=0.4.0
safetensors>=0.3.1
Dijkstar>=2.6.0

@ -1,13 +1,7 @@
import os
import platform
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
if platform.system() == "Darwin":
# Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
os.environ.setdefault("no_proxy", "*")
os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
import hivemind
import transformers
from packaging import version
@ -17,13 +11,13 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "2.3.0.dev2"
__version__ = "2.0.0.post1"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert (
version.parse("4.38.2") <= version.parse(transformers.__version__) < version.parse("4.39.0")
), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.39.0"
version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"
def _override_bfloat16_mode_default():

@ -7,8 +7,8 @@ This script may be used for launching lightweight CPU machines serving as bootst
This may be eventually merged to the hivemind upstream.
"""
import argparse
import time
from argparse import ArgumentParser
from secrets import token_hex
from hivemind.dht import DHT, DHTNode
@ -35,7 +35,7 @@ async def report_status(dht: DHT, node: DHTNode):
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = ArgumentParser()
parser.add_argument(
"--initial_peers",
nargs="*",
@ -73,9 +73,7 @@ def main():
help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
)
parser.add_argument(
"--use_auto_relay",
action="store_true",
help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
"--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
)
parser.add_argument(
"--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"

@ -1,10 +1,8 @@
import argparse
import logging
import configargparse
import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import limits
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger
from humanfriendly import parse_size
@ -70,17 +68,15 @@ def main():
parser.add_argument('--inference_max_length', type=int, default=None,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all operations (in total tokens)')
parser.add_argument('--max_batch_size', type=int, default=None,
help='The total number of tokens in the same batch will not exceed this value. '
'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
parser.add_argument('--attn_cache_tokens', type=int, default=None,
help='The number of past attention key/value pairs that will be stored between inference steps. '
'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')
'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
@ -98,22 +94,21 @@ def main():
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--max_alloc_timeout', type=float, default=600,
help="If the cache is full, the server will wait for memory to be freed up to this many seconds"
" before rejecting the request")
parser.add_argument('--alloc_timeout', type=float, default=5,
help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
'before rejecting the request')
parser.add_argument('--revision', type=str, default=None,
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
parser.add_argument('--throughput',
type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
type=lambda value: value if value in ['auto', 'eval'] else float(value),
default='auto',
help='Expected server throughput (a float measured in RPS). '
'If set to "auto" (default), the script evaluates network and compute throughput '
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
'If set to "dry_run", the script re-evaluates the throughput and exits.')
parser.add_argument('--update_period', type=float, required=False, default=120,
'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
parser.add_argument('--update_period', type=float, required=False, default=60,
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')
@ -125,14 +120,14 @@ def main():
help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")
group = parser.add_mutually_exclusive_group()
group.add_argument('--initial_peers', type=str, nargs='+', required=False, default=PUBLIC_INITIAL_PEERS,
group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
group.add_argument('--new_swarm', action='store_true',
help='Start a new private swarm (i.e., do not connect to any initial peers)')
parser.add_argument('--increase_file_limit', type=int, default=4096,
help='On *nix, increase the max number of files a server can open '
'before hitting "Too many open files" (set to zero to keep the system limit)')
parser.add_argument('--increase_file_limit', action='store_true',
help='On *nix, this will increase the max number of processes '
'a server can spawn before hitting "Too many open files"; Use at your own risk.')
parser.add_argument('--stats_report_interval', type=int, required=False,
help='Interval between two reports of batch processing performance statistics')
@ -157,11 +152,11 @@ def main():
"weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
parser.add_argument("--skip_reachability_check", action='store_true',
help="Skip checking this server's reachability via health.petals.dev "
help="Skip checking this server's reachability via health.petals.ml "
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")
parser.add_argument("--adapters", nargs='*', default=(),
parser.add_argument("--adapters", nargs='+', default=(),
help="List of pre-loaded LoRA adapters that can be used for inference or training")
# fmt:on
@ -188,10 +183,8 @@ def main():
args["startup_timeout"] = args.pop("daemon_startup_timeout")
file_limit = args.pop("increase_file_limit")
if file_limit:
limits.logger.setLevel(logging.WARNING)
limits.increase_file_limit(file_limit, file_limit)
if args.pop("increase_file_limit"):
increase_file_limit()
compression_type = args.pop("compression").upper()
compression = getattr(CompressionType, compression_type)
@ -212,10 +205,6 @@ def main():
validate_version()
if not torch.backends.openmp.is_available():
# Necessary to prevent the server from freezing after forks
torch.set_num_threads(1)
server = Server(
**args,
host_maddrs=host_maddrs,

@ -1,4 +1,4 @@
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@ -1,35 +0,0 @@
import dataclasses
import os
from typing import Optional, Sequence, Union
from hivemind import PeerID
from petals.constants import PUBLIC_INITIAL_PEERS
_max_retries = os.getenv("PETALS_MAX_RETRIES")
DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None
@dataclasses.dataclass
class ClientConfig:
initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, send requests only to these servers
blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, do not use these servers
use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
max_retries: Optional[int] = DEFAULT_MAX_RETRIES # max number of retries before an exception (default: inf)
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 3 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update

@ -3,9 +3,10 @@ import json
import os
import re
import tempfile
from contextvars import ContextVar
import threading
from typing import List, Optional, Tuple, Union
import torch
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils
@ -21,14 +22,21 @@ class FromPretrainedMixin:
model_name_or_path: Union[str, os.PathLike, None],
*args,
low_cpu_mem_usage: Optional[bool] = None,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
**kwargs,
):
model_name_or_path = get_compatible_model_repo(model_name_or_path)
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
if torch_dtype is None:
# torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
# torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
torch_dtype = "auto"
with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
return super().from_pretrained(
model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
)
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
"low_cpu_mem_usage(`bool`, *optional*)",
@ -39,16 +47,18 @@ class FromPretrainedMixin:
)
_ignored_keys = ContextVar("ignored_keys", default=None)
_shard_config = threading.local()
_shard_config.ignored_keys = None
@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
token = _ignored_keys.set(patterns)
try:
prev_patterns = _shard_config.ignored_keys
_shard_config.ignored_keys = patterns
yield
finally:
_ignored_keys.reset(token)
_shard_config.ignored_keys = prev_patterns
def patched_get_checkpoint_shard_files(
@ -56,7 +66,7 @@ def patched_get_checkpoint_shard_files(
) -> Tuple[List[str], dict]:
"""Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
should_ignore_keys = _ignored_keys.get() is not None
should_ignore_keys = _shard_config.ignored_keys is not None
tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
with tempdir_ctx as tempdir:
if should_ignore_keys:
@ -67,7 +77,7 @@ def patched_get_checkpoint_shard_files(
index["weight_map"] = {
param_name: filename
for param_name, filename in index["weight_map"].items()
if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get())
if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
}
n_loaded_shards = len(set(index["weight_map"].values()))
logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

@ -7,18 +7,22 @@ import uuid
from typing import AsyncIterator, List, Optional, Tuple
import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind import (
MSGPackSerializer,
anext,
deserialize_torch_tensor,
get_logger,
nested_flatten,
serialize_torch_tensor,
)
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs
from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__name__)
@ -32,7 +36,7 @@ class _ServerInferenceSession:
def __init__(
self,
config: ClientConfig,
config: SequenceManagerConfig,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
@ -59,7 +63,7 @@ class _ServerInferenceSession:
@classmethod
async def create(
cls,
config: ClientConfig,
config: SequenceManagerConfig,
p2p: P2P,
span: RemoteSpanInfo,
uid: ModuleUID,
@ -71,7 +75,7 @@ class _ServerInferenceSession:
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
config.connect_timeout,
config.request_timeout,
)
return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
@ -84,7 +88,12 @@ class _ServerInferenceSession:
break # this message means "done sending"
def step(
self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
*,
step_id: str,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
@ -109,8 +118,23 @@ class _ServerInferenceSession:
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
input_tensors = (inputs, prompts, hypo_ids)
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
@ -120,25 +144,13 @@ class _ServerInferenceSession:
if next_servers:
request_metadata["next_servers"] = next_servers
request_metadata["args_structure"] = args_structure
# TODO: make possible to use different compression method for different tensors
server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
compression = server_side_inference_schema[0].compression
inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
# TODO: create more explicit way to check servers schema and client's structure
assert len(input_tensors) >= len(
server_side_inference_schema
), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(input_tensors, inference_schema)
for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
],
metadata=MSGPackSerializer.dumps(request_metadata),
)
@ -210,8 +222,7 @@ class InferenceSession:
self._server_sessions = []
self._position = 0
self._max_length = max_length
self.output_ids = None
self.past_key_values = None
self.last_token_id = None
@property
def num_blocks(self) -> int:
@ -256,9 +267,7 @@ class InferenceSession:
assert not self._closed and not self._server_sessions
return self
def step(
self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@ -268,21 +277,11 @@ class InferenceSession:
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
inputs_device = inputs.device
inputs_dtype = inputs.dtype
inputs = inputs.cpu()
prompts = prompts.cpu()
hypo_ids = hypo_ids.cpu()
step_id = str(uuid.uuid4())
n_input_tokens = inputs.shape[1]
@ -303,7 +302,7 @@ class InferenceSession:
server_session = self._server_sessions[server_idx]
inputs = server_session.step(
inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
)
server_idx += 1
@ -336,7 +335,7 @@ class InferenceSession:
n_prev_spans = len(self._server_sessions)
update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
if attempt_no >= 1:
logger.debug(
logger.info(
f"Due to a server failure, remote attention caches "
f"from block {block_idx} to {update_end} will be regenerated"
)
@ -370,13 +369,3 @@ class InferenceSession:
def __del__(self):
self.close()
@property
def last_token_id(self) -> Optional[torch.Tensor]: # Backward compatibility with Petals < 2.1.0
return self.output_ids[:, -1:] if self.output_ids is not None else None
@last_token_id.setter
def last_token_id(self, value: torch.Tensor): # Backward compatibility with Petals < 2.1.0
if self.output_ids is None:
raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
self.output_ids[:, -1:] = value

@ -1,7 +1,8 @@
import dataclasses
import platform
from typing import Union
from typing import Optional, Union
import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@ -67,10 +68,11 @@ class LMHead(nn.Module):
assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
if not self._bf16_warning_shown:
logger.warning(
"Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
"To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
)
if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
logger.warning(
"Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
"Consider loading the model with torch_dtype='float32'"
)
self._bf16_warning_shown = True
hidden_states = hidden_states.float()

@ -76,9 +76,9 @@ def force_non_empty_weights():
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
"""
possibly_patched_register_parameter = nn.Module.register_parameter
nn.Module.register_parameter = _original_register_parameter
try:
possibly_patched_register_parameter = nn.Module.register_parameter
nn.Module.register_parameter = _original_register_parameter
yield
finally:
nn.Module.register_parameter = possibly_patched_register_parameter

@ -12,55 +12,53 @@ from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_U
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.config import ClientConfig
from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
timeout=config.request_timeout,
timeout=timeout,
)
return [deserialize_torch_tensor(t) for t in outputs.tensors]
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
timeout=config.request_timeout,
timeout=timeout,
)
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout)
outputs = aiter_with_timeout(outputs, config.request_timeout)
outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
outputs = aiter_with_timeout(outputs, timeout)
return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout)
grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout)
grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
grad_inputs = aiter_with_timeout(grad_inputs, timeout)
return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
@ -69,7 +67,7 @@ async def run_remote_forward(
stub: StubBase,
rpc_info: RPCInfo,
*inputs: torch.Tensor,
config: ClientConfig,
timeout: float,
metadata: Optional[bytes] = None,
**kwargs,
) -> Tuple[torch.Tensor, ...]:
@ -85,20 +83,26 @@ async def run_remote_forward(
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = tuple(nested_flatten((inputs, kwargs)))
forward_inputs = (inputs, kwargs)
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
compression = args_schema[0].compression
forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
forward_inputs = nested_flatten(forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# TODO: create more explicit way to check servers schema and client's structure
assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, forward_schema)
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
)
)
@ -106,7 +110,7 @@ async def run_remote_forward(
size = sum(t.element_size() * t.nelement() for t in inputs)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
@ -114,8 +118,10 @@ async def run_remote_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
*inputs_and_grad_outputs: torch.Tensor,
config: ClientConfig,
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
timeout: float,
metadata: Optional[bytes] = None,
**kwargs,
) -> Sequence[torch.Tensor]:
@ -124,14 +130,16 @@ async def run_remote_backward(
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
outputs_schema = rpc_info["outputs_schema"]
compression = args_schema[0].compression
backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)
# TODO: create more explicit way to check servers schema and client's structure
assert (
len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1
), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
# Asynchronous serialization
loop = asyncio.get_running_loop()
@ -145,5 +153,5 @@ async def run_remote_backward(
size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
return deserialized_grad_inputs

@ -1,164 +1,349 @@
import contextlib
import dataclasses
from contextvars import ContextVar
from typing import Any, ContextManager, Dict, List, Optional, Tuple
from typing import List, Optional
import torch
import transformers
from hivemind.utils.logging import get_logger
from torch import Tensor
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation.utils import ModelOutput
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from
from petals.utils.generation_algorithms import (
BeamSearchAlgorithm,
DecodingAlgorithm,
GreedyAlgorithm,
NucleusAlgorithm,
SamplingAlgorithm,
TopKAlgorithm,
)
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
logger = get_logger(__name__)
class RemotePastKeyValues(Cache):
"""only keeps the number of seen tokens. pretends to be a legit cache"""
def __init__(self) -> None:
super().__init__()
self.seen_tokens = 0
self.hypo_ids: Optional[torch.LongTensor] = None
def __getitem__(self, _index: int) -> List[torch.Tensor]:
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return self.seen_tokens
def get_max_length(self) -> Optional[int]:
return None
def update_seen(self, new_seen: int) -> None:
self.seen_tokens += new_seen
def reorder_cache(self, beam_idx):
raise NotImplementedError("Beam search reordering is not implemented yet")
_skipped_tokens = ContextVar("skipped_tokens", default=0)
class _SkipTokensMixin:
# This override is used in RemoteGenerationMixin by has to be defined in a class not named as "GenerationMixin"
# due to how transformers.PreTrainedModel.can_generate() works
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
input_ids = input_ids[:, _skipped_tokens.get() :]
_skipped_tokens.set(0)
return super().prepare_inputs_for_generation(input_ids, **kwargs)
class RemoteGenerationMixin(_SkipTokensMixin):
class RemoteGenerationMixin:
"""
This class is an upgrade to `transformers.GenerationMixin` that:
- Designed to be compatible with most `transformers.GenerationMixin` strategies and options
- Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
you don't have to rerun the prefix through all the servers to generate each new token
- Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
- If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
The class exposes can be used for:
- *greedy decoding*.
- *multinomial, top-k and top-p sampling*.
- *beam-search decoding*
This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
However, it has some differences for remote usage.
"""
@docstring_from(RemoteSequential.active_session)
@property
def active_session(self) -> Optional[InferenceSession]:
return self.transformer.h.active_session
def inference_session(self, **kwargs) -> InferenceSession:
"""
Returns an inference session for the model's RemoteSequential module.
@docstring_from(RemoteSequential.use_session)
def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
return self.transformer.h.use_session(session)
:param max_length: Maximal expected length of inference results. Servers use this parameter
to calculate the size of attention caches allocated to this client.
"""
@docstring_from(RemoteSequential.inference_session)
def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
return self.transformer.h.inference_session(**kwargs)
@docstring_from(transformers.GenerationMixin.generate.__doc__)
@torch.inference_mode()
def generate(
self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
):
self._fix_generate_kwargs(kwargs)
if inputs is None:
inputs = kwargs.pop("input_ids", None)
if session is not None:
# If a session specified explicitly, use it
context_manager = self.use_session(session)
elif self.active_session is not None:
# If there's an active session, don't do anything
context_manager = contextlib.nullcontext(self.active_session)
self,
inputs: Optional[torch.Tensor] = None,
*,
do_sample: Optional[bool] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
num_beams: Optional[int] = 1,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
max_length: Optional[int] = None,
max_new_tokens: Optional[int] = None,
decoding_algorithm: Optional[DecodingAlgorithm] = None,
provided_constraints: List[ABCBloomConstraint] = [],
num_return_sequences: Optional[int] = None,
session: Optional[InferenceSession] = None,
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head.
:param inputs: The input tokens to the model.
:param do_sample: Whether to sample from the model predictions or take the argmax.
:param temperature: The temperature to use for sampling.
:param top_k: The number of results to return.
:param top_p: The cumulative probability of results to return.
:param num_beams: The number of beams to use for beam search.
:param bos_token_id: The id of the beginning of sentence token.
:param eos_token_id: The id of the end of sentence token.
:param pad_token_id: The id of the padding token.
:param max_length: The maximum number of tokens in the output (including input tokens).
:param max_new_tokens: The maximum number of tokens to generate.
:param decoding_algorithm: The decoding algorithm to use.
:param provided_constraints: A list of constraints to use.
:param num_return_sequences: How many hypothesis from the beam will be in output.
"""
prefix_length = 0 if inputs is None else inputs.size(1)
prefix_length += self.config.pre_seq_len
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
if max_length is not None and max_new_tokens is None:
max_new_tokens = max_length - prefix_length
assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
elif max_length is None and max_new_tokens is not None:
max_length = prefix_length + max_new_tokens
resuming_session = session is not None and session.last_token_id is not None
if num_beams > 1 and resuming_session:
raise NotImplementedError(
"Resuming inference session in .generate() along with beam search is not supported yet"
)
if inputs is not None:
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
if resuming_session:
inputs = torch.cat([session.last_token_id, inputs], dim=1)
else:
# If there's no active session, create a new one
max_length = kwargs.get("max_length")
max_new_tokens = kwargs.get("max_new_tokens")
assert (max_length is None) != (
max_new_tokens is None
), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
session_max_length = self.transformer.config.pre_seq_len
if max_length is not None:
session_max_length += max_length
if resuming_session:
inputs = session.last_token_id
else:
session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
context_manager = self.inference_session(max_length=session_max_length)
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
batch_size = inputs.size(0)
if decoding_algorithm is None:
if do_sample:
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
elif num_beams is not None and num_beams > 1:
decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
else:
if top_k is not None or top_p is not None:
logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling")
decoding_algorithm = GreedyAlgorithm()
if num_beams > 1:
inputs = torch.cat([inputs] * num_beams, dim=0)
if batch_size > 1:
# TODO: resolve padding problem
logger.warning(
f"You set batch_size {batch_size} within beam search generation. "
f"Be careful, results on sequences with different length may be padded wrong way"
)
if num_return_sequences is None:
num_return_sequences = 1
assert num_return_sequences <= num_beams, (
f"You want more sequences than the beam has."
" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
)
constraints = self._get_constraints(
inputs=inputs,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
provided_constraints=provided_constraints,
)
if session is None:
context_manager = self.inference_session(max_length=max_length)
else:
context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it
with context_manager as session:
# Prepend the tokens from the previous .generate() call
n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
if n_prev_tokens > 0:
if kwargs.get("num_beams", 1) > 1:
logger.warning(
"Beam search will not work properly in the resumed petals.InferenceSession "
"since intermediate beam entries are lost"
)
if inputs is not None:
inputs = torch.cat([session.output_ids, inputs], dim=1)
else:
inputs = session.output_ids
# Don't actually run all previous tokens through the transformer,
# but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
_skipped_tokens.set(max(0, n_prev_tokens - 1))
if self._supports_cache_class and "past_key_values" not in kwargs:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(session.position)
kwargs["past_key_values"] = past_key_values
result = super().generate(inputs, *args, **kwargs)
sequences = result.sequences if isinstance(result, ModelOutput) else result
# Save tokens from this .generate() call
session.output_ids = sequences
# Crop the last tokens from the previous call
sequences = sequences[:, n_prev_tokens:].clone()
if isinstance(result, ModelOutput):
result.sequences = sequences
outputs = []
# Find samples with padded inputs.
# They will be changed before all of the samples have right length.
if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
else:
result = sequences
return result
@staticmethod
def _fix_generate_kwargs(kwargs: dict):
# Suppress inappropriate "Both max_new_tokens and max_length" HF warning
if "max_length" in kwargs and kwargs["max_length"] is None:
del kwargs["max_length"]
# Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
do_sample = kwargs.get("do_sample")
if isinstance(do_sample, int):
kwargs["do_sample"] = bool(do_sample)
@staticmethod
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
outputs += [inputs]
last_token_id = None
seq_idx = outputs[0].size(1)
hypo_ids = torch.arange(outputs[0].size(0))
while True:
hidden_state = self.transformer.word_embeddings(outputs[-1])
intermediate_prompts = None
if self.config.pre_seq_len > 0 and len(outputs) == 1:
prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
hidden_state = torch.cat([prompts, hidden_state], dim=1)
hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
hidden_state = self.transformer.ln_f(hidden_state)
lm_logits = self.lm_head(hidden_state)
for constraint in constraints:
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
# If some samples were padded, change only these samples
if seq_idx < inputs.size(1):
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
last_token_id = (~pad_token_mask) * inputs[
:, seq_idx : seq_idx + 1
] + pad_token_mask * last_token_id
# TODO: refactor outputs
if num_beams > 1:
for i in range(len(outputs), 1, -1):
outputs[i - 1] = outputs[i - 1][hypo_ids]
outputs.append(last_token_id)
session.last_token_id = last_token_id
seq_idx += 1
if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
break
outputs = torch.cat(outputs, dim=-1)
if resuming_session:
outputs = outputs[:, 1:]
if num_beams > 1:
pre_return_idx = [
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
]
return_idx = torch.cat(pre_return_idx, dim=0)
outputs = outputs[return_idx]
return outputs
def greedy_search(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses greedy search.
:param input_ids: The input tokens to the model.
:param max_length: The maximum length of the sequence to generate.
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
"""
return self.generate(
inputs=input_ids,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=GreedyAlgorithm(),
provided_constraints=provided_constraints,
)
def sample(
self,
input_ids: torch.LongTensor,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
:param: input_ids: The input tokens to the model.
:param: temperature: The temperature to use for sampling.
:param: top_k: The number of samples to use for top_k sampling.
:param: top_p: The probability of using top_p sampling.
:param: max_length: The maximum length of the sequence to generate.
:param: pad_token_id: The id of the padding token.
:param: eos_token_id: The id of the end of sentence token.
:param: provided_constraints: A list of constraints to use.
"""
return self.generate(
inputs=input_ids,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
provided_constraints=provided_constraints,
)
def beam_search(
self,
input_ids: torch.LongTensor,
num_beams: int = 1,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses beam search.
:param input_ids: The input tokens to the model.
:param num_beams: The number of beams to use.
:param max_length: The maximum length of the sequence to generate.
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
"""
decoding_algorithm = BeamSearchAlgorithm(
num_beams=num_beams,
batch_size=input_ids.size(0),
)
return self.generate(
inputs=input_ids,
num_beams=num_beams,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=decoding_algorithm,
provided_constraints=provided_constraints,
)
def beam_sample(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
raise NotImplementedError
def group_beam_search(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
raise NotImplementedError
def _choose_sample_algorithm(
self,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> DecodingAlgorithm:
if (top_k is not None) and (top_p is not None):
raise ValueError("You have to provide only top_k or top_p for sampling")
if top_k is not None:
return TopKAlgorithm(top_k, temperature)
elif top_p is not None:
return NucleusAlgorithm(top_p, temperature)
else:
return SamplingAlgorithm(temperature)
def _get_constraints(
self,
inputs: Optional[torch.Tensor] = None,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> List[ABCBloomConstraint]:
constraints = []
constraints.extend(provided_constraints)
constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
return constraints

@ -1,18 +1,16 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Optional, Union
import torch
from hivemind import DHT, get_logger
from torch import nn
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.routing import RemoteSequenceManager
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
from petals.utils.misc import DUMMY
logger = get_logger(__name__)
@ -24,7 +22,7 @@ class RemoteSequential(nn.Module):
def __init__(
self,
config: ClientConfig,
config: SequenceManagerConfig,
*,
sequence_manager: Optional[RemoteSequenceManager] = None,
dht: Optional[DHT] = None,
@ -47,52 +45,11 @@ class RemoteSequential(nn.Module):
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
self.sequence_manager = sequence_manager
self._active_session = ContextVar("active_session", default=None)
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
if self.active_session is None:
assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
else:
return self.active_session.step(inputs, prompts, **kwargs)
@property
def active_session(self) -> Optional[InferenceSession]:
"""
If called inside `with model.inference_session(...):` or `with model.use_session(...):`,
returns an active InferenceSession. Otherwise, returns None.
"""
return self._active_session.get()
@property
def position(self) -> int:
"""Returns the prefix length (in tokens) in the active inference session or zero if no session is active."""
return self.active_session.position if self.active_session is not None else 0
@contextmanager
def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
"""Inside this context, forward() will use an _existing_ InferenceSession provided as the argument."""
token = self._active_session.set(session)
try:
yield session
finally:
self._active_session.reset(token)
@contextmanager
def inference_session(self, **kwargs) -> InferenceSession:
"""
Inside this context, forward() will use a _new_ InferenceSession created with given parameters.
:param max_length: Maximal expected length of inference results. Servers use this parameter
to calculate the size of attention caches allocated to this client.
"""
with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session):
yield session
assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
return outputs
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
return RemoteSequential(
@ -107,5 +64,8 @@ class RemoteSequential(nn.Module):
def __len__(self):
return len(self.sequence_manager)
def inference_session(self, **kwargs) -> InferenceSession:
return InferenceSession(self.sequence_manager, **kwargs)
def extra_repr(self) -> str:
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

@ -1,2 +1 @@
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
"""Client-side functions responsible for choosing the best server, """

@ -1,15 +1,17 @@
import dataclasses
import time
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
from hivemind import get_logger
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.utils.dht import compute_spans
logger = get_logger(__name__)
T = TypeVar("T")
@dataclasses.dataclass
class RemoteSequenceInfo:
"""
@ -28,7 +30,7 @@ class RemoteSequenceInfo:
last_updated_time: Optional[float]
@classmethod
def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
block_uids = tuple(block_uids)
empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
empty_spans = tuple([] for _ in range(len(block_uids)))
@ -37,7 +39,7 @@ class RemoteSequenceInfo:
def __getitem__(self, ix: slice):
assert isinstance(ix, slice)
block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
return RemoteSequenceInfo(
block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
)
@ -45,23 +47,60 @@ class RemoteSequenceInfo:
def __len__(self):
return len(self.block_uids)
def update_(self, new_block_infos: List[RemoteModuleInfo]):
def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
assert len(new_block_infos) == len(self.block_uids)
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
if info is None:
logger.debug(f"Found no block info for block {uid}")
continue
if not isinstance(info, RemoteModuleInfo):
logger.warning(f"Unexpected dht entry type for {uid}: {info}")
continue
if not info.servers:
logger.debug(f"Found no active peers for block {uid}")
continue
if info.uid != uid:
logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
continue
self.block_infos[block_index].servers = info.servers
self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
self.last_updated_time = time.perf_counter()
@staticmethod
def _sort_spans(block_infos: List[RemoteModuleInfo]):
spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
spans_by_priority.sort(key=lambda span: span.length, reverse=True)
def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
closed_spans = []
active_spans = {}
for block_index, info in enumerate(block_infos):
if info is not None:
for peer_id, server_info in info.servers.items():
if server_info.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(
peer_id=peer_id,
start=block_index,
end=block_index + 1,
server_info=server_info,
)
else: # peer_id in active_spans
active_spans[peer_id].end = block_index + 1
for peer_id in list(active_spans.keys()):
if (
info is None
or peer_id not in info.servers
or info.servers[peer_id].state != ServerState.ONLINE
or block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans, f"spans: {active_spans}"
closed_spans.sort(key=lambda span: span.length, reverse=True)
spans_containing_block = tuple([] for _ in range(len(block_infos)))
for span in spans_by_priority:
spans_containing_block = tuple(list() for _ in range(len(block_infos)))
for span in closed_spans:
for block_index in range(span.start, span.end):
spans_containing_block[block_index].append(span)
return spans_by_priority, spans_containing_block
return closed_spans, spans_containing_block

@ -7,8 +7,7 @@ import logging
import random
import threading
import time
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod
import dijkstar
@ -19,27 +18,39 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
from petals.client.config import ClientConfig
import petals.dht_utils
from petals.client.routing.sequence_info import RemoteSequenceInfo
from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
from petals.utils.dht import get_remote_module_infos
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
logger = get_logger(__name__)
class SequenceManagerConfig(ClientConfig):
def __init__(self, *args, **kwargs):
warnings.warn(
"petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. "
"This alias will be removed in Petals 2.2.0+",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
@dataclasses.dataclass
class SequenceManagerConfig:
initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
use_server_to_server: bool = True # Use direct server-to-server communication
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf)
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 5 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
@dataclasses.dataclass
@ -70,7 +81,7 @@ class RemoteSequenceManager:
def __init__(
self,
config: ClientConfig,
config: SequenceManagerConfig,
block_uids: Sequence[ModuleUID],
*,
dht: Optional[DHT] = None,
@ -104,9 +115,6 @@ class RemoteSequenceManager:
self._thread_start_lock = threading.Lock()
self.policy = NoSpendingPolicy()
self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
self.ping_aggregator = PingAggregator(dht)
if state.banned_peers is None:
@ -117,23 +125,7 @@ class RemoteSequenceManager:
if state.sequence_info.last_updated_time is not None:
assert block_uids == state.sequence_info.block_uids
self._thread.ready.set() # no need to await the first dht fetch
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
if peer_ids is None:
return None
result = set()
for peer_id in peer_ids:
if isinstance(peer_id, PeerID):
result.add(peer_id)
elif isinstance(peer_id, str):
result.add(PeerID.from_base58(peer_id))
else:
raise TypeError(
f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
)
return result
self._need_latest_infos = True
def make_sequence(
self,
@ -300,8 +292,6 @@ class RemoteSequenceManager:
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
client_server_rtts = self.ping_aggregator.to_dict()
span_sequence = []
current_index = start_index
while current_index < end_index:
@ -309,13 +299,7 @@ class RemoteSequenceManager:
if not candidate_spans:
raise MissingBlocksError(current_index)
# We choose longer servers to minimize the number of hops but leave some randomization
# to distribute the load. We also exclude servers known to be unreachable.
eps = 1e-6
span_weights = np.array(
[span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
dtype=np.float64,
)
span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
assert chosen_span.start <= current_index < chosen_span.end
@ -340,18 +324,21 @@ class RemoteSequenceManager:
def _update(self):
"""Perform an immediate and synchronous refresh, may take time"""
new_block_infos = get_remote_module_infos(
new_block_infos = petals.dht_utils.get_remote_module_infos(
self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
)
for block_info in new_block_infos:
# Apply allow and block lists
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if (self.allowed_servers is None or peer_id in self.allowed_servers)
and (self.blocked_servers is None or peer_id not in self.blocked_servers)
}
if not block_info:
continue
# Apply whitelist, if defined
if self.config.allowed_servers is not None:
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
}
# Remove temporarily banned peers, unless there are no peers left
valid_servers = {
@ -373,13 +360,9 @@ class RemoteSequenceManager:
self.state.sequence_info.update_(new_block_infos)
first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
middle_servers = [
span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
]
last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged))
pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
@ -470,21 +453,14 @@ class RemoteSequenceManager:
return 0
return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
def get_request_metadata(
self, protocol: str, args_structure: Any = None, *args, **kwargs
) -> Optional[Dict[str, Any]]:
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
"""
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
:param args: request-specific inputs, typically block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return dict(
points=self.policy.get_points(protocol, *args, **kwargs),
active_adapter=self.config.active_adapter,
args_structure=args_structure,
)
return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
def shutdown(self):
self._thread.shutdown()
@ -537,7 +513,7 @@ class MissingBlocksError(RuntimeError):
def __init__(self, block_indices: Union[int, Sequence[int]]):
super().__init__(
f"No servers holding blocks {block_indices} are online. "
f"You can check the public swarm's state at https://health.petals.dev "
f"You can check the public swarm's state at http://health.petals.ml "
f"If there are not enough servers, please connect your GPU: "
f"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity "
)

@ -12,11 +12,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -68,18 +67,16 @@ async def sequential_forward(
span = sequences.popleft()
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
metadata = sequence_manager.get_request_metadata(
"rpc_forward", args_structure, span_uids, *flat_tensors
)
metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
(outputs,) = await run_remote_forward(
span_uids,
stub,
sequence_manager.rpc_info,
*flat_tensors,
config=sequence_manager.config,
*inputs_and_prompts,
timeout=sequence_manager.config.request_timeout,
metadata=MSGPackSerializer.dumps(metadata),
)
@ -152,22 +149,19 @@ async def sequential_backward(
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
flat_tensors, args_structure = pack_args_kwargs(
inputs, *grad_outputs_cpu, prompts[span.start : span.end]
)
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
metadata = sequence_manager.get_request_metadata(
"rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
"rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
)
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids,
stub,
sequence_manager.rpc_info,
*flat_tensors,
config=sequence_manager.config,
inputs,
grad_outputs,
prompts[span.start : span.end],
timeout=sequence_manager.config.request_timeout,
metadata=MSGPackSerializer.dumps(metadata),
)
grad_outputs = [grad_outputs]
@ -230,7 +224,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
if prompts is None or is_dummy(prompts):
if is_dummy(prompts):
prompt_batches = [DUMMY] * len(input_batches)
else:
prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

@ -1,18 +1,13 @@
import torch
PUBLIC_INITIAL_PEERS = [
# IPv4 DNS addresses
"/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
# IPv6 DNS addresses
"/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
# Reserved IPs
"/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
"/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
"/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
]
# The reachability API is currently used only when connecting to the public swarm
REACHABILITY_API_URL = "https://health.petals.dev"
REACHABILITY_API_URL = "http://health.petals.ml"
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

@ -6,30 +6,13 @@ import pydantic
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
from petals.server.memory_cache import Handle
ModuleUID = str
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
dht_prefix, index = uid.split(UID_DELIMITER)
return dht_prefix, int(index)
@pydantic.dataclasses.dataclass
class ModelInfo:
num_blocks: pydantic.conint(ge=1, strict=True)
repository: Optional[str] = None
def to_dict(self) -> dict:
return dataclasses.asdict(self)
@classmethod
def from_dict(cls, source: dict):
return cls(**source)
class ServerState(Enum):
OFFLINE = 0
JOINING = 1
@ -44,9 +27,6 @@ class ServerInfo:
state: ServerState
throughput: RPS
start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
end_block: Optional[pydantic.conint(ge=0, strict=True)] = None
public_name: Optional[str] = None
version: Optional[str] = None
@ -92,22 +72,12 @@ class RemoteSpanInfo:
server_info: ServerInfo
@property
def length(self) -> int:
def length(self):
return self.end - self.start
@property
def state(self) -> ServerState:
return self.server_info.state
@property
def throughput(self) -> float:
return self.server_info.throughput
RPCInfo = Dict[str, Any]
Handle = int
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:

@ -1,9 +1,124 @@
import warnings
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
warnings.warn(
"petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
DeprecationWarning,
stacklevel=2,
)
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
from petals.utils.dht import *
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, awaits for declaration to finish, otherwise runs in background
:param throughput: specify your performance in terms of compute throughput
:param expiration_time: declared modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
elif expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules

@ -1,4 +1,2 @@
from petals.models.bloom import *
from petals.models.falcon import *
from petals.models.llama import *
from petals.models.mixtral import *

@ -6,11 +6,8 @@ See commit history for authorship.
from typing import Optional, Tuple
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
from petals.utils.misc import is_dummy
class WrappedBloomBlock(BloomBlock):
def forward(
@ -24,22 +21,12 @@ class WrappedBloomBlock(BloomBlock):
):
assert attention_mask is None, "Non-causal attention masks are not supported yet"
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None and is_dummy(layer_past[0]):
# Bloom cannot use cache if it was misconsctructed(e.g. Dummy tensors)
# In this case, fallback to the old code:
layer_past = None
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_length,
)
attention_mask = attention_mask.bool()
attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
return super().forward(
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
)

@ -5,15 +5,15 @@ from hivemind import get_logger
from transformers.models.bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.bloom.block import WrappedBloomBlock
logger = get_logger(__name__)
class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedBloomBlock
attn_class = BloomAttention
block_prefix = "h"
@ -30,6 +30,5 @@ class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfi
if loading_from_repo and dht_prefix is None:
# We need "-petals" for backward compatibility with Petals < 1.2.0
dht_prefix = str(model_name_or_path) + "-petals"
dht_prefix = dht_prefix.replace(".", "-")
logger.info(f"Using DHT prefix: {dht_prefix}")
return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)

@ -4,14 +4,13 @@ import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.models.bloom.config import DistributedBloomConfig
@ -40,15 +39,16 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
for k, v in kwargs.items():
if not (v is None or v is False):
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -59,50 +59,32 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
if use_prompts:
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
hidden_states = self.h(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
else:
hidden_states = self.h(hidden_states)
# Remove prefix
if use_prompts:
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = hidden_states[:, self.pre_seq_len :]
if past_key_values is None:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(hidden_states.size(1))
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
past_key_values=None,
hidden_states=None,
attentions=None,
)
@ -112,7 +94,6 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
_supports_cache_class = True
config_class = DistributedBloomConfig
@ -124,58 +105,6 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
# Initialize weights and apply final processing
self.post_init()
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
) -> dict:
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def _temporary_reorder_cache(self, past_key_values, beam_idx):
return past_key_values
def get_output_embeddings(self):
return self.lm_head

@ -1,15 +0,0 @@
from petals.models.falcon.block import WrappedFalconBlock
from petals.models.falcon.config import DistributedFalconConfig
from petals.models.falcon.model import (
DistributedFalconForCausalLM,
DistributedFalconForSequenceClassification,
DistributedFalconModel,
)
from petals.utils.auto_config import register_model_classes
register_model_classes(
config=DistributedFalconConfig,
model=DistributedFalconModel,
model_for_causal_lm=DistributedFalconForCausalLM,
model_for_sequence_classification=DistributedFalconForSequenceClassification,
)

@ -1,480 +0,0 @@
"""
Falcon intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
See commit history for authorship.
"""
import math
from functools import partial
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconConfig,
FalconDecoderLayer,
FalconLinear,
FalconMLP,
FalconModel,
LayerNorm,
build_alibi_tensor,
dropout_add,
rotate_half,
)
KVCache = Tuple[torch.Tensor, torch.Tensor]
INFERENCE_MAX_LENGTH = 8192
def apply_rotary(query, key, cos, sin):
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
class OptimizedFalconRotaryEmbedding(nn.Module):
def __init__(self, head_dim: int, base=10000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.head_dim = head_dim
self.seq_len_cached = -1
self.cuda_graph = None
self.input_surface = None
self.static_outputs = None
def _optimized_apply_rotary(self, query, key, cos, sin):
if self.cuda_graph is None:
self.cuda_graph = torch.cuda.CUDAGraph()
self.input_surface = (query, key, cos, sin)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
apply_rotary(*self.input_surface)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(self.cuda_graph):
self.static_outputs = apply_rotary(*self.input_surface)
inputs = (query, key, cos, sin)
for static_input, data in zip(self.input_surface, inputs):
static_input.copy_(data)
self.cuda_graph.replay()
return tuple(o.detach() for o in self.static_outputs)
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
total_length = seq_len + past_key_values_length
if self.seq_len_cached == -1:
# warm up the cache
total_length = max(INFERENCE_MAX_LENGTH, total_length)
if total_length > self.seq_len_cached:
with torch.inference_mode(False):
self.seq_len_cached = total_length
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
return (
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
)
def forward(self, query, key, past_key_values_length=0):
batch, seq_len, head_dim = query.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
return self._optimized_apply_rotary(query, key, cos, sin)
else:
return apply_rotary(query, key, cos, sin)
def split_heads(
fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch, seq_len, _ = fused_qkv.shape
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
key = torch.broadcast_to(key, query.shape)
value = torch.broadcast_to(value, query.shape)
query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
return query, key, value
class OptimizedFalconAttention(FalconAttention):
def __init__(self, config: FalconConfig):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.hidden_dropout = config.hidden_dropout
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = self.inv_norm_factor
if config.new_decoder_architecture:
qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
elif config.multi_query:
qkv_out_dim = self.hidden_size + 2 * self.head_dim
else:
qkv_out_dim = 3 * self.hidden_size
self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
if self.new_decoder_architecture:
self._split_heads = partial(
split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
)
self.split_graph = None
self.input_surface = None
self.static_outputs = None
def _optimized_split_heads(self, fused_qkv):
if self.split_graph is None:
self.split_graph = torch.cuda.CUDAGraph()
self.input_surface = fused_qkv
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
self._split_heads(fused_qkv)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(self.split_graph):
self.static_outputs = self._split_heads(self.input_surface)
self.input_surface.copy_(fused_qkv)
self.split_graph.replay()
return tuple(o.detach() for o in self.static_outputs)
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
assert not output_attentions
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
if (
self.new_decoder_architecture
and hidden_states.size(1) == 1
and torch.is_inference_mode_enabled()
and hidden_states.device.type == "cuda"
):
query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
else:
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
num_kv_heads = self.num_heads
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
if alibi is None:
attn_output = F.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
)
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
output_tensor = self.dense(attn_output)
return output_tensor, present
else:
matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
# adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
# equivalent and more performant, but there might be a numerical difference. If you're reading this
# and you'd like to experiment and maybe file a PR, feel free!
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)
output_tensor = self.dense(context_layer)
if output_attentions:
return output_tensor, present, attention_probs
else:
return output_tensor, present
class OptimizedFalconDecoderLayer(FalconDecoderLayer):
def __init__(self, config: FalconConfig):
nn.Module.__init__(self)
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
self.self_attention = OptimizedFalconAttention(config)
if self.config.alibi or not config.new_decoder_architecture:
if config.new_decoder_architecture:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_graph = None
self.static_input = None
self.static_outputs = None
def _optimized_apply_ln(self, hidden_states):
if self.ln_graph is None:
self.ln_graph = torch.cuda.CUDAGraph()
self.static_input = hidden_states
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
self.ln_attn(hidden_states)
self.ln_mlp(hidden_states)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(self.ln_graph):
ln_attn_output = self.ln_attn(hidden_states)
ln_mlp_output = self.ln_mlp(hidden_states)
self.static_outputs = (ln_attn_output, ln_mlp_output)
self.static_input.copy_(hidden_states)
self.ln_graph.replay()
return tuple(o.detach() for o in self.static_outputs)
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
residual = hidden_states
if self.config.new_decoder_architecture:
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
else:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
attn_outputs = self.self_attention(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual = dropout_add(
attention_output, residual, self.config.attention_dropout, training=self.training
)
mlp_layernorm_out = self.post_attention_layernorm(residual)
outputs = attn_outputs[1:]
mlp_output = self.mlp(mlp_layernorm_out)
if self.config.new_decoder_architecture or self.config.parallel_attn:
mlp_output += attention_output
output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
class WrappedFalconBlock(OptimizedFalconDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[KVCache] = None,
use_cache: bool = False,
**kwargs,
):
assert attention_mask is None
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None:
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
past_length = 0 if layer_past is None else layer_past[0].shape[1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None and self.config.alibi:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
outputs = super().forward(
hidden_states,
*args,
attention_mask=attention_mask,
alibi=alibi,
layer_past=layer_past,
use_cache=use_cache,
**kwargs,
)
if use_cache:
present_key_value = outputs[-1]
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
outputs = outputs[:-1] + (present_key_value,)
return outputs
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
if self.config.new_decoder_architecture:
key_states = self._expand_states(key_states)
value_states = self._expand_states(value_states)
return (key_states, value_states)
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value
if self.config.new_decoder_architecture:
key_states = self._collapse_states(key_states)
value_states = self._collapse_states(value_states)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
key_states = key_states.permute(0, 2, 1)
return (key_states, value_states)
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
return state
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
state = state[:, :, 0]
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
return state

@ -1,48 +0,0 @@
import os
from typing import Optional, Union
from hivemind import get_logger
from transformers.models.falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import FalconAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.falcon.block import WrappedFalconBlock
from petals.utils.auto_config import DefaultRevisionMixin
logger = get_logger(__name__)
class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedFalconBlock
attn_class = FalconAttention
block_prefix = "transformer.h"
@property
def num_key_value_groups(self) -> int:
if self.new_decoder_architecture:
return self.num_attention_heads // self.num_kv_heads
if self.multi_query:
return self.num_attention_heads
return 1
@classmethod
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
):
if "180B" in model_name_or_path.upper():
logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license")
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
if loading_from_repo and dht_prefix is None:
dht_prefix = str(model_name_or_path)
dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts
dht_prefix = dht_prefix.replace(".", "-")
logger.info(f"Using DHT prefix: {dht_prefix}")
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
config = result[0] if isinstance(result, tuple) else result
if config.pad_token_id is None:
config.pad_token_id = 0
return result

@ -1,154 +0,0 @@
from typing import Optional
import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon import (
FalconForCausalLM,
FalconForSequenceClassification,
FalconModel,
FalconPreTrainedModel,
)
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.falcon.config import DistributedFalconConfig
from petals.utils.auto_config import DefaultRevisionMixin
logger = get_logger(__name__)
class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
"""FalconModel, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]
config_class = DistributedFalconConfig
def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
super().__init__(config)
assert len(self.h) == 0
config.num_hidden_layers = n_layer
self.h = RemoteSequential(config, dht=dht)
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
self.init_prompts(config)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
assert (
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
), f"Non-consecutive position_ids are not supported, {position_ids=}"
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
if use_prompts:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
hidden_states = self.h(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
# Remove prefix
if use_prompts:
hidden_states = hidden_states[:, self.pre_seq_len :]
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=RemotePastKeyValues(),
hidden_states=None,
attentions=None,
)
@property
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
return nn.Identity()
class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
config_class = DistributedFalconConfig
def __init__(self, config: DistributedFalconConfig):
FalconPreTrainedModel.__init__(self, config)
self.transformer = DistributedFalconModel(config)
self.lm_head = LMHead(config)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
class DistributedFalconForSequenceClassification(
DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
):
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
config_class = DistributedFalconConfig
def __init__(self, config: DistributedFalconConfig):
FalconPreTrainedModel.__init__(self, config)
self.num_labels = config.num_labels
self.transformer = DistributedFalconModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()

@ -3,229 +3,13 @@ LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
repeat_kv,
rotate_half,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from petals.utils.cuda_graphs import make_inference_graphed_callable
def apply_rotary_pos_emb(q, k, cos, sin):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class OptimizedLlamaAttention(LlamaAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._rotary_graph = None
def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
if self._rotary_graph is None:
self._rotary_graph = make_inference_graphed_callable(
apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
)
return self._rotary_graph(query_states, key_states, cos, sin)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert not output_attentions
if position_ids is None:
past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
position_ids = torch.arange(
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
).unsqueeze(0)
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
else:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.self_attn = OptimizedLlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_attn_graph = None
self.post_attn_graph = None
def _optimized_input_layernorm(self, hidden_states):
if self.pre_attn_graph is None:
self.pre_attn_graph = make_inference_graphed_callable(
self.input_layernorm.forward, sample_args=(hidden_states,)
)
return self.pre_attn_graph(hidden_states)
def _optimized_output_layernorm(self, hidden_states):
if self.post_attn_graph is None:
self.post_attn_graph = make_inference_graphed_callable(
self.post_attention_layernorm.forward, sample_args=(hidden_states,)
)
return self.post_attn_graph(hidden_states)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
hidden_states = self._optimized_input_layernorm(hidden_states)
else:
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
hidden_states = self._optimized_output_layernorm(hidden_states)
else:
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
class WrappedLlamaBlock(LlamaDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
@ -247,18 +31,22 @@ class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
seq_length_with_past = seq_length_with_past + past_key_values_length
past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
assert position_ids is None
if position_ids is None:
device = hidden_states.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
attention_mask = LlamaModel._prepare_decoder_attention_mask(
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
outputs = super().forward(

@ -5,15 +5,15 @@ from hivemind import get_logger
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.llama.block import WrappedLlamaBlock
logger = get_logger(__name__)
class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedLlamaBlock
attn_class = LlamaAttention
block_prefix = "model.layers"
@ -35,7 +35,6 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
if loading_from_repo and dht_prefix is None:
dht_prefix = str(model_name_or_path)
dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts
dht_prefix = dht_prefix.replace(".", "-")
if not dht_prefix.endswith("-hf"):
dht_prefix += "-hf"
logger.info(f"Using DHT prefix: {dht_prefix}")
@ -43,5 +42,4 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
config = result[0] if isinstance(result, tuple) else result
config.pretraining_tp = 1 # This may give less accurate results but it doesn't matter if we use quantization
config.use_cache = True # use_cache=False leads to identical results but is slower and not supported by Petals
return result

@ -10,7 +10,7 @@ from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassifi
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.models.llama.config import DistributedLlamaConfig
@ -39,16 +39,16 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> BaseModelOutputWithPast:
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
for k, v in kwargs.items():
if not (v is None or v is False):
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -59,55 +59,32 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
if cache_position is not None:
assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
assert (
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
), f"Non-consecutive position_ids are not supported, {position_ids=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0
if use_prompts:
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = inputs_embeds
output_shape = input_shape + (hidden_states.size(-1),)
hidden_states = self.layers(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
if past_key_values is None:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(hidden_states.size(1))
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
else:
hidden_states = self.layers(hidden_states)
# Remove prefix
if use_prompts:
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = hidden_states[:, self.pre_seq_len :]
# Add last hidden state
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
past_key_values=None,
hidden_states=None,
attentions=None,
)

@ -1,15 +0,0 @@
from petals.models.mixtral.block import WrappedMixtralBlock
from petals.models.mixtral.config import DistributedMixtralConfig
from petals.models.mixtral.model import (
DistributedMixtralForCausalLM,
DistributedMixtralForSequenceClassification,
DistributedMixtralModel,
)
from petals.utils.auto_config import register_model_classes
register_model_classes(
config=DistributedMixtralConfig,
model=DistributedMixtralModel,
model_for_causal_lm=DistributedMixtralForCausalLM,
model_for_sequence_classification=DistributedMixtralForSequenceClassification,
)

@ -1,114 +0,0 @@
import json
from typing import Optional, Tuple
import torch
from transformers import MixtralConfig
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
class WrappedMixtralBlock(MixtralDecoderLayer):
def __init__(self, config: MixtralConfig, layer_idx: int):
super().__init__(config, layer_idx)
self._attn_implementation = config._attn_implementation
self.sliding_window = config.sliding_window
self.layer_idx = layer_idx
def forward(
self,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
**kwargs
):
batch_size, seq_length, _ = hidden_states.shape
seq_length_with_past = seq_length
past_key_values_length = 0
past_key_value = layer_past
if past_key_value is not None:
past_key_values_length = past_key_value[0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
past_key_value = DynamicCache()
past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
past_key_value._seen_tokens = past_key_values_length
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa":
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
sliding_window=self.sliding_window,
)
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
outputs = super().forward(
hidden_states,
*args,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
**kwargs
)
if use_cache:
present_key_value = outputs[-1]
present_key_value = present_key_value[self.layer_idx]
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
outputs = outputs[:-1] + (present_key_value,)
return outputs
def _reorder_cache_from_bloom(
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
) -> Tuple[torch.Tensor]:
# TODO: Move to mixin
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
key_states = key_states.view(
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
value_states = value_states.view(*key_states.shape)
return (key_states, value_states)
def _reorder_cache_to_bloom(
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
) -> Tuple[torch.Tensor]:
# TODO: Move to mixin
key_states, value_states = key_value
value_states = value_states.view(
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
key_states = key_states.view(*value_states.shape)
key_states = key_states.permute(0, 2, 1)
return (key_states, value_states)

@ -1,36 +0,0 @@
import os
from typing import Optional, Union
from hivemind import get_logger
from transformers.models.mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.mixtral.block import WrappedMixtralBlock
logger = get_logger(__name__)
class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedMixtralBlock
attn_class = MixtralAttention
block_prefix = "model.layers"
num_key_value_groups = 1
@classmethod
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
):
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
if loading_from_repo and dht_prefix is None:
dht_prefix = str(model_name_or_path)
dht_prefix = dht_prefix.replace(".", "-")
logger.info(f"Using DHT prefix: {dht_prefix}")
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
config = result[0] if isinstance(result, tuple) else result
if config.pad_token_id is None:
config.pad_token_id = 0
return result

@ -1,178 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
from hivemind import DHT
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral import (
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralModel,
MixtralPreTrainedModel,
)
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.mixtral.config import DistributedMixtralConfig
from petals.utils.auto_config import DefaultRevisionMixin
logger = get_logger(__name__)
class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel):
"""MixtralModel, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
config_class = DistributedMixtralConfig
def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
super().__init__(config)
assert len(self.layers) == 0
config.num_hidden_layers = n_layer
self.layers = RemoteSequential(config, dht=dht)
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
self.init_prompts(config)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
assert (
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
), f"Non-consecutive position_ids are not supported, {position_ids=}"
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
assert not output_router_logits, f"{output_router_logits=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
if use_prompts:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = inputs_embeds
output_shape = input_shape + (hidden_states.size(-1),)
if past_key_values is None:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(hidden_states.size(1))
hidden_states = self.layers(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
# Remove prefix
if use_prompts:
hidden_states = hidden_states[:, self.pre_seq_len :]
# Add last hidden state
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.view(output_shape)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
@property
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
return self.embed_tokens
@property
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return nn.Identity()
@property
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
return self.layers
@property
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return self.norm
class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
config_class = DistributedMixtralConfig
def __init__(self, config: DistributedMixtralConfig):
MixtralPreTrainedModel.__init__(self, config)
self.model = DistributedMixtralModel(config)
self.lm_head = LMHead(config)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
@property
def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin
return self.model
class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
config_class = DistributedMixtralConfig
def __init__(self, config: DistributedMixtralConfig):
MixtralPreTrainedModel.__init__(self, config)
self.num_labels = config.num_labels
self.model = DistributedMixtralModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
@property
def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin
return self.model

@ -16,7 +16,7 @@ from transformers import PretrainedConfig
from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import get_size_in_bytes, is_dummy
from petals.utils.misc import is_dummy
logger = get_logger(__name__)
@ -27,13 +27,7 @@ class TransformerBackend(ModuleBackend):
_peft_module = None
def __init__(
self,
*args,
config: PretrainedConfig,
memory_cache: MemoryCache,
backend_dtype: torch.dtype,
max_chunk_size_bytes: int,
**kwargs,
self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
):
import petals.utils.peft as _peft_module
@ -43,8 +37,6 @@ class TransformerBackend(ModuleBackend):
assert isinstance(self.module, TensorParallel)
self.config = config
self.memory_cache = memory_cache
self.max_chunk_size_bytes = max_chunk_size_bytes
for name, param in self.module.named_parameters():
assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
for name, buf in self.module.named_buffers():
@ -63,7 +55,6 @@ class TransformerBackend(ModuleBackend):
)
self.dtype = backend_dtype
self.dtype_bytes = get_size_in_bytes(self.dtype)
self.shard_num_heads = []
for shard in self.module.module_shards:
for submodule in shard.modules():
@ -83,7 +74,7 @@ class TransformerBackend(ModuleBackend):
self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)
self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
"""Create tensor descriptors for attention cache tensors used during inference_step"""
@ -91,8 +82,6 @@ class TransformerBackend(ModuleBackend):
cache_tensors = []
for device, num_heads in zip(self.module.devices, self.shard_num_heads):
num_heads //= self.config.num_key_value_groups
if hasattr(self.config, "num_key_value_heads"):
num_heads = self.config.num_key_value_heads
keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
cache_tensors.extend((keys, values))
@ -116,40 +105,14 @@ class TransformerBackend(ModuleBackend):
inference_info: InferenceMetadata,
) -> Tuple[torch.Tensor, ...]:
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
seq_len = hidden_states.shape[1]
with self.memory_cache.use_cache(
*inference_info.cache_handles
) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
self._reorder_cache_inplace(cache_tensors, hypo_ids)
# We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
# reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes`
# is at least 4-6x less than `autograd_memory`.
max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
for offset in range(0, seq_len, max_chunk_length):
hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
output_hidden_states_chunk, new_kvs = self.module.forward(
hidden_states_chunk, layer_past=layer_past, use_cache=True
)
if seq_len > max_chunk_length:
output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
else:
output_hidden_states = output_hidden_states_chunk # saves one memcopy
layer_past = new_kvs
hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
return (output_hidden_states,)
def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:
# We assume that attention logit matrices are the main thing that consumes memory, given that
# the model uses multi-query attention
batch_size, seq_length, hidden_size = hidden_states.shape
worst_case_length = inference_info.prefix_length + seq_length
attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length
return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)
return (hidden_states,)
def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
"""If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""

@ -1,230 +0,0 @@
"""
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
"""
from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
from hivemind.utils.nested import nested_flatten
from petals.data_structures import Handle, InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import unpack_args_kwargs
# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1
logger = get_logger(__name__)
async def run_rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, *_ = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
async def run_rpc_backward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
inputs, grad_outputs, prompts, *_ = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
async def iterate_rpc_inference(
requested_uids: Sequence[ExpertUID],
requested_backends: Sequence[TransformerBackend],
active_adapter: Optional[str],
input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
cache_handles: Sequence[Sequence[Handle]],
*,
max_length: int,
prioritizer: TaskPrioritizerBase,
points: int,
quant_type: QuantType,
args_structure: Any = None,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
assert len(cache_handles) == len(requested_backends)
prefix_length = 0
point_per_piece = points / max_length if max_length > 0 else 0.0
async for request, step_metadata in input_iterator:
flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, hypo_ids, *_ = flat_tensors
batch_size, length_increment, _ = hidden_states.shape
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
can_merge_pools = batch_size * length_increment <= merge_max_tokens
priority = prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
# A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache.
if hidden_states.numel() > 0:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
if can_merge_pools:
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
else:
for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
(hidden_states,) = await backend.inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, prompt, priority=priority
)
# serialize and send last layer outputs
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
can_push = not has_prompts
yield output_tensors, can_push, step_metadata
# prepare for next step
prefix_length += length_increment

@ -1,23 +1,54 @@
from typing import Dict, List
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
from hivemind import PeerID, get_logger
from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.utils.dht import compute_spans
from petals.data_structures import RemoteModuleInfo, ServerState
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
logger = get_logger(__name__)
def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
# We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
@dataclass
class Span:
start: int
end: int
throughput: float
state: ServerState
@property
def length(self):
return self.end - self.start
def move_to(self, new_start: int) -> None:
self.start, self.end = new_start, new_start + self.length
def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
spans = {}
throughputs = np.zeros(len(module_infos))
for block, module in enumerate(module_infos):
if module is None:
continue
# We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
for peer_id, server in sorted(module.servers.items()):
if server.state == ServerState.OFFLINE:
continue
throughputs = np.zeros(total_blocks)
for span in sorted(spans.values(), key=lambda span: span.peer_id):
throughputs[span.start : span.end] += span.throughput
return throughputs
if peer_id in spans:
spans[peer_id].start = min(spans[peer_id].start, block)
spans[peer_id].end = max(spans[peer_id].start, block + 1)
else:
spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
throughputs[block] += server.throughput
return spans, throughputs
def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
@ -25,26 +56,19 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
return min(options)[-1]
def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
spans = compute_spans(module_infos, min_state=ServerState.JOINING)
throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
_, throughputs = compute_spans(module_infos)
start = _choose_best_start(throughputs, num_blocks)
return list(range(start, start + num_blocks))
def _move_span(span: RemoteSpanInfo, new_start: int):
span.start, span.end = new_start, new_start + span.length
def should_choose_other_blocks(
local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
) -> bool:
if balance_quality > 1.0:
return True # Forces rebalancing on each check (may be used for debugging purposes)
spans = compute_spans(module_infos, min_state=ServerState.JOINING)
throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
spans, throughputs = compute_spans(module_infos)
initial_throughput = throughputs.min()
eps = 1e-3
@ -64,7 +88,7 @@ def should_choose_other_blocks(
return False # This server is on its best place already
throughputs[local_span.start : local_span.end] += local_span.throughput * eps
_move_span(local_span, new_start)
local_span.move_to(new_start)
throughputs[local_span.start : local_span.end] += local_span.throughput
moved = True
@ -81,7 +105,7 @@ def should_choose_other_blocks(
throughputs[span.start : span.end] += span.throughput * eps
if span.start != new_start:
_move_span(span, new_start)
span.move_to(new_start)
moved = True
throughputs[span.start : span.end] += span.throughput

@ -2,19 +2,16 @@ from typing import Optional, Union
import torch
from accelerate import init_empty_weights
from transformers import PretrainedConfig, PreTrainedModel
from transformers import PretrainedConfig
from petals.models.mixtral.block import WrappedMixtralBlock
from petals.utils.convert_block import QuantType
from petals.utils.misc import get_size_in_bytes
def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
"""If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
if dtype not in ("auto", None):
return dtype
if config.torch_dtype not in ("auto", None, torch.float32):
# If config specifies float32, we override it to the default dtype below
if config.torch_dtype not in ("auto", None):
return config.torch_dtype
return torch.bfloat16
@ -33,13 +30,13 @@ def get_block_size(
), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
with init_empty_weights(include_buffers=True):
block = get_model_block(config)
block = config.block_class(config)
n_params = sum(param.numel() for param in block.parameters())
if location == "memory":
if quant_type == QuantType.NONE:
dtype = resolve_block_dtype(config, dtype)
bytes_per_value = get_size_in_bytes(dtype)
bytes_per_value = torch.finfo(dtype).bits // 8
elif quant_type == QuantType.INT8:
bytes_per_value = 1
elif quant_type == QuantType.NF4:
@ -48,18 +45,6 @@ def get_block_size(
raise ValueError(f"Unsupported quant_type={quant_type}")
elif location == "disk":
dtype = resolve_block_dtype(config, "auto")
bytes_per_value = get_size_in_bytes(dtype)
bytes_per_value = torch.finfo(dtype).bits // 8
return round(n_params * bytes_per_value * (1 + eps))
def get_model_block(config, layer_idx: int = 0):
"""
The function to create a model block based on the block class
kwargs argument **only** is necessary for specific classes, like Mixtral.
They will not be passed to other block constructors.
"""
if config.block_class == WrappedMixtralBlock:
config = PreTrainedModel._autoset_attn_implementation(config)
return config.block_class(config, layer_idx)
return config.block_class(config)

@ -8,23 +8,19 @@ If necessary, one can rewrite this to implement a different behavior, such as:
"""
import json
import time
from contextlib import suppress
from typing import Dict, Optional, Union
import safetensors
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from huggingface_hub import get_hf_file_metadata, hf_hub_url
from huggingface_hub.utils import EntryNotFoundError
from transformers import PretrainedConfig, PreTrainedModel
from transformers import PretrainedConfig
from transformers.utils import get_file_from_repo
from petals.constants import DTYPE_MAP
from petals.models.mixtral import WrappedMixtralBlock
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.server.block_utils import resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.hf_auth import always_needs_auth
@ -44,7 +40,7 @@ def load_pretrained_block(
max_disk_space: Optional[int] = None,
) -> nn.Module:
if config is None:
config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
config = AutoDistributedConfig.from_pretrained(model_name, token=token)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
@ -52,7 +48,7 @@ def load_pretrained_block(
torch_dtype = resolve_block_dtype(config, torch_dtype)
with init_empty_weights():
block = get_model_block(config, layer_idx=block_index)
block = config.block_class(config)
block_prefix = f"{config.block_prefix}.{block_index}."
state_dict = _load_state_dict_from_repo(
@ -65,7 +61,7 @@ def load_pretrained_block(
)
# dummy load, check that keys match
report = block.load_state_dict(state_dict, strict=False)
report = block.load_state_dict(state_dict, strict=True)
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters():
@ -75,8 +71,7 @@ def load_pretrained_block(
param = param.to(torch_dtype)
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
logger.info(f"Loaded {model_name} block {block_index}")
logger.debug(f"Details: {report}")
logger.info(f"Loaded {model_name} block {block_index}, {report}")
return block
@ -95,14 +90,11 @@ def _load_state_dict_from_repo(
if always_needs_auth(model_name) and token is None:
token = True
index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir)
if index_file.endswith(".index.json"): # Sharded model
path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir)
if path is None:
# _find_index_file() told that a file exists but we can't get it (e.g., it just disappeared)
raise ValueError(f"Failed to get file {index_file}")
with open(path) as f:
index_file = get_file_from_repo(
model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
)
if index_file is not None: # Sharded model
with open(index_file) as f:
index = json.load(f)
filenames = {
filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
@ -110,15 +102,14 @@ def _load_state_dict_from_repo(
if not filenames:
raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
else: # Non-sharded model
filenames = {index_file}
filenames = {"pytorch_model.bin"}
logger.debug(f"Loading {block_prefix}* from {filenames}")
state_dict = {}
for filename in filenames:
shard_state_dict = _load_state_dict_from_repo_file(
shard_state_dict = _load_state_dict_from_file(
model_name,
filename,
block_prefix=block_prefix,
revision=revision,
token=token,
cache_dir=cache_dir,
@ -133,42 +124,10 @@ def _load_state_dict_from_repo(
return state_dict
INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"]
def _find_index_file(
model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str
) -> str:
# If we have cached weights (e.g., Pickle from older Petals versions), reuse them
for filename in INDEX_FILES:
path = get_file_from_repo(
model_name,
filename,
revision=revision,
use_auth_token=token,
cache_dir=cache_dir,
local_files_only=True,
)
if path is not None:
return filename
# If we don't, prefer Safetensors when possible
# (we don't download files here since we can't account for max_disk_space in case of large files)
for filename in INDEX_FILES:
with suppress(EntryNotFoundError):
get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token)
return filename
raise ValueError(
f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist"
)
def _load_state_dict_from_repo_file(
def _load_state_dict_from_file(
model_name: str,
filename: str,
*,
block_prefix: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
cache_dir: str,
@ -187,7 +146,7 @@ def _load_state_dict_from_repo_file(
local_files_only=True,
)
if path is not None:
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
return torch.load(path, map_location="cpu")
except Exception:
logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
@ -212,18 +171,7 @@ def _load_state_dict_from_repo_file(
)
if path is None:
raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
return torch.load(path, map_location="cpu")
except Exception as e:
logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
time.sleep(delay)
def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict:
if path.endswith(".bin"):
return torch.load(path, map_location="cpu")
if path.endswith(".safetensors"):
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)}
raise ValueError(f"Unknown weight format: {path}")

@ -6,7 +6,7 @@ import multiprocessing as mp
import sys
from enum import Enum
from itertools import chain
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import torch
from async_timeout import timeout
@ -29,11 +29,12 @@ from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, Handle, ModuleUID
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__name__)
@ -71,7 +72,6 @@ class TransformerConnectionHandler(ConnectionHandler):
session_timeout: float,
step_timeout: float,
task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
quant_type: QuantType,
):
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
@ -89,7 +89,6 @@ class TransformerConnectionHandler(ConnectionHandler):
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self._prioritizer = task_prioritizer
self.quant_type = quant_type
async def add_p2p_handlers(self, *args, **kwargs) -> None:
if self._listener_task is None:
@ -148,10 +147,9 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
args_structure = metadata.get("args_structure")
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(
@ -165,32 +163,78 @@ class TransformerConnectionHandler(ConnectionHandler):
f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
)
point_per_piece = points / max_length if max_length > 0 else 0.0
batch_size = request.tensors[0].size[0] if request.tensors else 1
prefix_length = 0
async with self._allocate_cache(
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
) as cache_handles:
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
first_request = request
background_tasks = set()
async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
requested_uids=requested_uids,
requested_backends=requested_backends,
active_adapter=self._get_active_adapter(metadata),
input_iterator=self._iterate_inference_steps(
request, requests, session_id, requested_uids, context
),
cache_handles=cache_handles,
max_length=max_length,
prioritizer=self._prioritizer,
points=points,
quant_type=self.quant_type,
args_structure=args_structure,
async for request, metadata in self._iterate_inference_steps(
first_request, requests, session_id, requested_uids, context
):
if can_push:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
priority = self._prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
# serialize and send last layer outputs
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
if not has_prompts:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
background_tasks.add(task) # Keep reference until it is done to save it from GC
task.add_done_callback(background_tasks.discard)
yield runtime_pb2.ExpertResponse(tensors=output_tensors)
# prepare for next step
prefix_length += length_increment
finally:
self._log_request("rpc_inference.close", requested_uids, context)
@ -303,7 +347,7 @@ class TransformerConnectionHandler(ConnectionHandler):
anext_task.cancel()
get_push_task.cancel()
return
except Exception:
except:
logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
raise
@ -360,18 +404,16 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await run_rpc_forward(
hidden_states = await _rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@ -389,18 +431,16 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
hidden_states = await run_rpc_forward(
hidden_states = await _rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
# Split the serialized_output for streaming and respond to client
@ -442,18 +482,16 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grads = await run_rpc_backward(
grads = await _rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@ -469,18 +507,16 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await run_rpc_backward(
grads = await _rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
# Split the serialized_grad_inputs for streaming and respond
for tensor in self._serialize_grads(grads, requested_backends, metadata):
@ -531,19 +567,14 @@ class TransformerConnectionHandler(ConnectionHandler):
@contextlib.asynccontextmanager
async def _allocate_cache(
self,
backends: Sequence[TransformerBackend],
*,
batch_size: int,
max_length: int,
timeout: Optional[float],
self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
) -> Sequence[Sequence[Handle]]:
"""
Allocate memory cache for all transformer blocks, return cache handle
:returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
"""
descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
yield nested_pack(handles, descriptors)
def _log_request(
@ -590,3 +621,105 @@ class TransformerConnectionHandler(ConnectionHandler):
result.update(block_info)
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
async def _rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, prompts = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
async def _rpc_backward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape

@ -12,26 +12,25 @@ import os
import time
from typing import AsyncContextManager, Dict, Optional, Sequence
import async_timeout
import hivemind
import torch
from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
from hivemind.utils import TensorDescriptor, get_logger
from petals.data_structures import Handle
from petals.utils.asyncio import shield_and_wait
from petals.utils.misc import get_size_in_bytes
logger = get_logger(__name__)
Handle = int
class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):
def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.max_alloc_timeout = max_alloc_timeout
self.alloc_timeout = alloc_timeout
self._lock_metadata = mp.Lock()
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
self.runtime_pid = os.getpid()
@ -48,14 +47,6 @@ class MemoryCache:
def current_size_bytes(self, value: int):
self._current_size.value = value
@property
def enqueued_size_bytes(self) -> int:
return self._enqueued_size.value
@enqueued_size_bytes.setter
def enqueued_size_bytes(self, value: int):
self._enqueued_size.value = value
@property
def bytes_left(self) -> int:
return self.max_size_bytes - self.current_size_bytes
@ -69,14 +60,11 @@ class MemoryCache:
self._handle_counter.value = value
@contextlib.asynccontextmanager
async def allocate_cache(
self, *descriptors: TensorDescriptor, timeout: float
) -> AsyncContextManager[Sequence[Handle]]:
async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
"""
Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
:param descriptors: one or more tensors tensor of this size, dtype, etc
:param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
:note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
if not, it will count maximum tensor allocation across devices for the purposes of size limit
@ -86,8 +74,6 @@ class MemoryCache:
"""
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
if self.max_alloc_timeout is not None:
timeout = min(timeout, self.max_alloc_timeout)
max_alloc_size = self.get_allocation_size(*descriptors)
gib = 1024**3
@ -98,80 +84,52 @@ class MemoryCache:
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
)
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
try:
handles = await shield_and_wait(alloc_task)
logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
yield handles
finally:
self._free(max_alloc_size, alloc_task)
await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
@staticmethod
def get_allocation_size(*descriptors: TensorDescriptor) -> int:
"""Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
alloc_size_by_device = {}
for descr in descriptors:
tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
return max(alloc_size_by_device.values())
async def _schedule_alloc(
self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
) -> Sequence[Handle]:
async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
"""
This method should be called inside asyncio.shield() because:
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
"""
try:
async with self._wait_for_free_memory(alloc_size, timeout):
with self._lock_metadata:
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
self.current_size_bytes += alloc_size
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
self._pipe_send.send((handles, descriptors))
return handles
except TimeoutError:
raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})")
@contextlib.asynccontextmanager
async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):
start_time = time.perf_counter()
loop = asyncio.get_event_loop()
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
if self.current_size_bytes + alloc_size > self.max_size_bytes:
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
self.current_size_bytes += alloc_size
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
self._pipe_send.send((handles, descriptors))
return handles
async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
"""
This method should be called inside asyncio.shield() because:
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
- _schedule_free() must finish freeing memory even in case of cancellation
"""
with self._enqueued_size.get_lock():
self._enqueued_size.value += alloc_size
allocated = False
try:
context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
# contextlib.AsyncExitStack() is used as a null context here
async with context_manager:
if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:
raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
async with enter_asynchronously(self._lock_acquire_memory):
if self.current_size_bytes + alloc_size > self.max_size_bytes:
if timeout == 0:
raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
elapsed_time = time.perf_counter() - start_time
remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None
await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
allocated = True
with self._enqueued_size.get_lock():
self._enqueued_size.value -= alloc_size
yield
except asyncio.TimeoutError:
raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
finally:
if not allocated:
with self._enqueued_size.get_lock():
self._enqueued_size.value -= alloc_size
def _free(self, alloc_size: int, alloc_task: asyncio.Task):
if alloc_task.exception() is not None:
return
handles = alloc_task.result()
with self._lock_metadata:
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
self._pipe_send.send((handles, None)) # signal runtime to free these handles
self.current_size_bytes -= alloc_size
self._memory_freed_event.set()
@ -182,10 +140,9 @@ class MemoryCache:
raise AllocationFailed(
f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
)
timeout = timeout if timeout != float("inf") else None
deadline = None if timeout is None else time.perf_counter() + timeout
while self.current_size_bytes + allocated_size > self.max_size_bytes:
remaining_time = None if timeout is None else deadline - time.perf_counter()
remaining_time = deadline - time.perf_counter() if timeout is not None else None
if not self._memory_freed_event.wait(remaining_time):
raise AllocationFailed(
f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"
@ -203,21 +160,22 @@ class MemoryCache:
assert os.getpid() == self.runtime_pid
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
# read creation/deletion requests from connection handlers
while self._pipe_recv.poll():
recv_handles, recv_data = self._pipe_recv.recv()
if recv_data is not None: # create new tensors
assert len(recv_handles) == len(recv_data)
for handle, descr in zip(recv_handles, recv_data):
self._allocated_tensors[handle] = descr.make_zeros()
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
else: # delete tensors by handle
for handle in recv_handles:
if handle not in self._allocated_tensors:
logger.warning(
f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
)
self._allocated_tensors.pop(handle, None)
with self._lock_metadata:
# read creation/deletion requests from connection handlers
while self._pipe_recv.poll():
recv_handles, recv_data = self._pipe_recv.recv()
if recv_data is not None: # create new tensors
assert len(recv_handles) == len(recv_data)
for handle, descr in zip(recv_handles, recv_data):
self._allocated_tensors[handle] = descr.make_zeros()
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
else: # delete tensors by handle
for handle in recv_handles:
if handle not in self._allocated_tensors:
logger.warning(
f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
)
self._allocated_tensors.pop(handle, None)
yield tuple(self._allocated_tensors[handle] for handle in handles)

@ -28,7 +28,7 @@ def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float
response = r.json()
if response["success"]:
logger.info("Server is reachable from the Internet. It will appear at https://health.petals.dev soon")
logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon")
return
if attempt_no == 0:
@ -37,7 +37,7 @@ def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float
logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
time.sleep(retry_delay)
except Exception as e:
logger.warning(f"Skipping reachability check because health.petals.dev is down: {repr(e)}")
logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
return
raise RuntimeError(
@ -140,7 +140,7 @@ class ReachabilityProtocol(ServicerBase):
protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
ready.set_result(True)
logger.debug("Reachability service started")
logger.info("Reachability service started")
async with protocol.serve(common_p2p):
await protocol._stop.wait()

@ -3,17 +3,13 @@ from __future__ import annotations
import gc
import math
import multiprocessing as mp
import os
import random
import sys
import threading
import time
from typing import Dict, List, Optional, Sequence, Union
import hivemind
import psutil
import torch
import torch.mps
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
@ -23,7 +19,8 @@ from transformers import PretrainedConfig
import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size, resolve_block_dtype
@ -34,8 +31,6 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.dht import declare_active_modules, get_remote_module_infos
from petals.utils.misc import get_size_in_bytes
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
from petals.utils.version import get_compatible_model_repo
@ -63,13 +58,12 @@ class Server:
inference_max_length: Optional[int] = None,
min_batch_size: int = 1,
max_batch_size: Optional[int] = None,
max_chunk_size_bytes: int = 256 * 1024 * 1024,
max_alloc_timeout: float = 600,
attn_cache_tokens: Optional[int] = None,
torch_dtype: str = "auto",
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
alloc_timeout: float = 5,
device: Optional[Union[str, torch.device]] = None,
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
@ -83,12 +77,12 @@ class Server:
sender_threads: int = 1,
balance_quality: float = 0.75,
mean_balance_check_period: float = 120,
mean_block_selection_delay: float = 5,
mean_block_selection_delay: float = 2.5,
token: Optional[Union[str, bool]] = None,
quant_type: Optional[QuantType] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
skip_reachability_check: bool = False,
reachable_via_relay: Optional[bool] = None,
dht_client_mode: Optional[bool] = None,
use_relay: bool = True,
use_auto_relay: bool = True,
adapters: Sequence[str] = (),
@ -110,7 +104,7 @@ class Server:
self.block_config = AutoDistributedConfig.from_pretrained(
converted_model_name_or_path,
use_auth_token=token,
token=token,
revision=revision,
)
@ -134,20 +128,20 @@ class Server:
for block_index in range(self.block_config.num_hidden_layers)
]
if reachable_via_relay is None:
if dht_client_mode is None:
is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
reachable_via_relay = is_reachable is False # if can't check reachability (returns None), run a full peer
logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")
dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer
logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}")
self.dht = DHT(
initial_peers=initial_peers,
start=True,
num_workers=self.block_config.num_hidden_layers,
use_relay=use_relay,
use_auto_relay=use_auto_relay,
client_mode=reachable_via_relay,
client_mode=dht_client_mode,
**kwargs,
)
self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None
self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
@ -158,25 +152,13 @@ class Server:
self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device(device.type, index=0)
self.device = device
torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
if device.type == "cpu" and torch_dtype == torch.float16:
raise ValueError(
f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
)
if device.type == "mps" and torch_dtype == torch.bfloat16:
logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
torch_dtype = torch.float16
self.torch_dtype = torch_dtype
if tensor_parallel_devices is None:
@ -187,7 +169,10 @@ class Server:
check_device_balance(self.tensor_parallel_devices)
if quant_type is None:
quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE
if device.type == "cuda":
quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
else:
quant_type = QuantType.NONE
self.quant_type = quant_type
logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
@ -198,15 +183,13 @@ class Server:
inference_max_length = 8192 if is_multiquery_attn else 2048
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
self.max_chunk_size_bytes = max_chunk_size_bytes
self.max_alloc_timeout = max_alloc_timeout
# For attention cache in GPU or RAM
if attn_cache_tokens is None:
attn_cache_tokens = 16384 if is_multiquery_attn else 4096
attn_cache_tokens = 32768 if is_multiquery_attn else 2048
cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
cache_values_per_block //= self.block_config.num_key_value_groups
self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
# For disk cache
self.cache_dir = cache_dir
@ -216,14 +199,13 @@ class Server:
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
num_blocks = self._choose_num_blocks()
if num_blocks is not None:
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
if block_indices is not None:
try:
start_block, end_block = [int(index.strip()) for index in block_indices.split(":")]
first_block_index, last_block_index = block_indices.split(":")
first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
except Exception as e:
raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
block_indices = range(start_block, end_block)
block_indices = range(first_block_index, last_block_index)
num_blocks = len(block_indices)
self.strict_block_indices, self.num_blocks = block_indices, num_blocks
@ -231,9 +213,10 @@ class Server:
self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
if throughput in ["auto", "eval", "dry_run"]:
force_eval = throughput in ["eval", "dry_run"]
self.alloc_timeout = alloc_timeout
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput_info = get_server_throughput(
converted_model_name_or_path,
self.block_config,
@ -242,13 +225,9 @@ class Server:
num_blocks=num_blocks,
quant_type=quant_type,
tensor_parallel_devices=self.tensor_parallel_devices,
reachable_via_relay=reachable_via_relay,
force_eval=force_eval,
force_eval=(throughput == "eval"),
cache_dir=cache_dir,
)
if throughput == "dry_run":
logger.info("Finished estimating throughput, exiting")
sys.exit(0)
else:
throughput_info = {"throughput": throughput}
self.server_info = ServerInfo(
@ -258,29 +237,24 @@ class Server:
adapters=tuple(adapters),
torch_dtype=str(torch_dtype).replace("torch.", ""),
quant_type=quant_type.name.lower(),
using_relay=reachable_via_relay,
using_relay=self.dht.client_mode,
**throughput_info,
)
self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
if not os.path.isdir(converted_model_name_or_path):
self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
self.balance_quality = balance_quality
self.mean_balance_check_period = mean_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay
self.module_container = None
self.stop = threading.Event()
def _choose_num_blocks(self) -> int:
assert self.device.type in ("cuda", "mps"), (
assert self.device.type == "cuda", (
"GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
"CPU-only servers in the public swarm are discouraged since they are much slower"
)
num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
if num_devices > 1:
assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
memory_per_device = tuple(
torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
)
@ -291,10 +265,8 @@ class Server:
"Please launch individual servers on each GPU or set --num_blocks manually to "
"override this exception."
)
elif self.device.type == "cuda":
total_memory = torch.cuda.get_device_properties(self.device).total_memory
else:
total_memory = psutil.virtual_memory().total
total_memory = torch.cuda.get_device_properties(self.device).total_memory
gib = 1024**3
# Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
@ -320,7 +292,7 @@ class Server:
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
logger.info(
f"Server will fill your GPU memory with {num_blocks} transformer blocks. "
f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
)
return num_blocks
@ -334,14 +306,12 @@ class Server:
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
alloc_timeout=self.alloc_timeout,
server_info=self.server_info,
model_info=self.model_info,
block_indices=block_indices,
num_handlers=self.num_handlers,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
max_chunk_size_bytes=self.max_chunk_size_bytes,
max_alloc_timeout=self.max_alloc_timeout,
inference_max_length=self.inference_max_length,
torch_dtype=self.torch_dtype,
cache_dir=self.cache_dir,
@ -384,7 +354,7 @@ class Server:
self._clean_memory_and_fds()
def _clean_memory_and_fds(self):
self.module_container = None
del self.module_container
gc.collect() # In particular, this closes unused file descriptors
if self.device.type == "cuda":
@ -397,8 +367,6 @@ class Server:
f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
f"{reserved_vram / gib:.1f} GiB reserved memory"
)
elif self.device.type == "mps":
torch.mps.empty_cache()
def _choose_blocks(self) -> List[int]:
if self.strict_block_indices is not None:
@ -417,10 +385,8 @@ class Server:
module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
def shutdown(self, timeout: Optional[float] = 5):
def shutdown(self):
self.stop.set()
if self.module_container is not None and self.module_container.is_alive():
self.module_container.join(timeout)
if self.reachability_protocol is not None:
self.reachability_protocol.shutdown()
@ -441,13 +407,11 @@ class ModuleContainer(threading.Thread):
converted_model_name_or_path: str,
block_config: PretrainedConfig,
attn_cache_bytes: int,
alloc_timeout: float,
server_info: ServerInfo,
model_info: ModelInfo,
block_indices: List[int],
min_batch_size: int,
max_batch_size: int,
max_chunk_size_bytes: int,
max_alloc_timeout: float,
torch_dtype: torch.dtype,
cache_dir: str,
max_disk_space: int,
@ -463,14 +427,13 @@ class ModuleContainer(threading.Thread):
**kwargs,
) -> ModuleContainer:
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
server_info.state = ServerState.JOINING
dht_announcer = ModuleAnnouncerThread(
module_uids,
dht,
server_info,
model_info,
block_config=block_config,
memory_cache=memory_cache,
update_period=update_period,
@ -514,7 +477,6 @@ class ModuleContainer(threading.Thread):
config=block_config,
memory_cache=memory_cache,
backend_dtype=torch_dtype,
max_chunk_size_bytes=max_chunk_size_bytes,
args_schema=(
BatchTensorDescriptor(
1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
@ -590,7 +552,6 @@ class ModuleContainer(threading.Thread):
request_timeout=request_timeout,
session_timeout=session_timeout,
step_timeout=step_timeout,
quant_type=QuantType[server_info.quant_type.upper()],
)
for i in range(num_handlers)
]
@ -679,7 +640,6 @@ class ModuleAnnouncerThread(threading.Thread):
module_uids: List[str],
dht: DHT,
server_info: ServerInfo,
model_info: ModelInfo,
*,
block_config: PretrainedConfig,
memory_cache: MemoryCache,
@ -692,26 +652,20 @@ class ModuleAnnouncerThread(threading.Thread):
self.module_uids = module_uids
self.dht = dht
self.server_info = server_info
self.model_info = model_info
self.memory_cache = memory_cache
self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
self.bytes_per_token //= block_config.num_key_value_groups
self.update_period = update_period
self.expiration = expiration
self.trigger = threading.Event()
self.dht_prefix = parse_uid(module_uids[0])[0]
block_indices = [parse_uid(uid)[1] for uid in module_uids]
self.server_info.start_block = min(block_indices)
self.server_info.end_block = max(block_indices) + 1
self.max_pinged = max_pinged
self.next_uids = [
f"{self.dht_prefix}{UID_DELIMITER}{i}"
for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1)
]
dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
start_block, end_block = min(block_indices), max(block_indices) + 1
self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
self.ping_aggregator = PingAggregator(self.dht)
def run(self) -> None:
@ -735,19 +689,10 @@ class ModuleAnnouncerThread(threading.Thread):
)
if self.server_info.state == ServerState.OFFLINE:
break
if not self.dht_prefix.startswith("_"): # Not private
self.dht.store(
key="_petals.models",
subkey=self.dht_prefix,
value=self.model_info.to_dict(),
expiration_time=get_dht_time() + self.expiration,
)
delay = self.update_period - (time.perf_counter() - start_time)
if delay < 0:
logger.warning(
f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
)
logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it")
self.trigger.wait(max(delay, 0))
self.trigger.clear()
@ -759,11 +704,12 @@ class ModuleAnnouncerThread(threading.Thread):
def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
pinged_servers.discard(self.dht.peer_id)
# Sample servers hosting the block after the last one (most likely continuations) separately
pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
if module_infos[-1] is not None:
# Sample servers hosting the block after the last one (most likely continuations) separately
pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
self.ping_aggregator.ping(list(pinged_servers))

@ -9,6 +9,7 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
import torch
from hivemind import get_logger
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.utils.mpfuture import ALL_STATES, MPFuture
logger = get_logger(__name__)
@ -26,7 +27,7 @@ class Task:
return self.future._uid
class PrioritizedTaskPool(threading.Thread):
class PrioritizedTaskPool(TaskPoolBase):
"""
Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
@ -56,41 +57,52 @@ class PrioritizedTaskPool(threading.Thread):
daemon=True,
start=False,
):
super().__init__(daemon=daemon, name=name)
self.process_func = process_func
# the lower the priority is, the more urgent it is to process this pool
self._priority = mp.Value(ctypes.c_double, 1.0)
super().__init__(process_func, daemon=daemon, name=name)
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.device = device
self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers
self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime
self._prioritizer_thread = threading.Thread(
name=self.name + "_prioritizer",
target=self._prioritize_tasks,
args=[self.submitted_tasks, self._ordered_tasks],
daemon=True,
)
self._dispatched_tasks = {}
self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
self.priority = float("inf"), float("inf") # (first task priority, first task timestamp)
self._stop = mp.Event()
if start:
self.start()
def run(self):
@staticmethod
def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
"""Read tasks from incoming queue and put them into a local priority queue"""
while True:
task = self.submitted_tasks.get()
task = submitted_tasks.get()
if task is None:
logger.debug("Shutting down prioritizer thread")
break
self._ordered_tasks.put(task, block=True)
ordered_tasks.put(task, block=True)
def start(self):
assert not self.is_alive() and not self._prioritizer_thread.is_alive()
self._prioritizer_thread.start()
super().start()
def terminate(self):
"""An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
self.shutdown()
def shutdown(self, timeout: float = 3):
self.submitted_tasks.put(None) # Shuts down self._prioritizer_thread
self._stop.set()
def shutdown(self):
self.submitted_tasks.put(None) # Shuts down self.run()
self.join(timeout)
if self.is_alive():
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
self.terminate()
def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
"""Add task to this pool's queue, return Future for its output"""
@ -151,6 +163,9 @@ class PrioritizedTaskPool(threading.Thread):
else:
task.future.set_exception(exception)
def run(self, *args, **kwargs):
self._stop.wait()
@property
def empty(self):
return not self.batch_receiver.poll()

@ -13,8 +13,9 @@ class TaskPrioritizerBase(ABC):
class DummyTaskPrioritizer(TaskPrioritizerBase):
"""Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
# Inference steps go first since they are more latency-sensitive
if kwargs.get("type") == "inference":
return 1.0
return 2.0 # Forward, backward
return 1.0 # inference steps go first since they are more latency-sensitive
return 2.0 # forward, backward

@ -9,14 +9,12 @@ from pathlib import Path
from typing import Dict, Optional, Sequence, Union
import torch
import torch.mps
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.misc import DUMMY_KEY_PAST
logger = get_logger(__name__)
@ -43,8 +41,6 @@ def get_server_throughput(
num_blocks: int,
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
reachable_via_relay: bool,
relay_penalty: float = 0.2,
force_eval: bool = False,
cache_dir: Optional[str] = None,
) -> Dict[str, float]:
@ -53,11 +49,11 @@ def get_server_throughput(
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
lock_path = Path(cache_dir, "throughput.lock")
cache_path = Path(cache_dir, "throughput_v5.json")
cache_path = Path(cache_dir, "throughput_v4.json")
# We use the system-wide lock since only one process at a time can measure the host throughput
os.makedirs(lock_path.parent, exist_ok=True)
with open(lock_path, "wb+") as lock_fd:
with open(lock_path, "wb") as lock_fd:
logger.info("Loading throughput info")
fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
# The OS will release the lock when lock_fd is closed or the process is killed
@ -98,12 +94,9 @@ def get_server_throughput(
# E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
average_blocks_used = (num_blocks + 1) / 2
throughput = throughput_info["forward_rps"] / average_blocks_used
network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1)
throughput = min(throughput, network_rps)
throughput = min(throughput, throughput_info.get("network_rps", math.inf))
throughput_info["throughput"] = throughput
logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")
logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
return throughput_info
@ -116,10 +109,13 @@ def measure_throughput_info(
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
) -> Dict[str, float]:
"""Measure network and compute throughput in forward pass tokens per second"""
logger.info(
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
)
return {
throughput_info = {
"inference_rps": measure_compute_rps(
config,
device,
@ -140,39 +136,37 @@ def measure_throughput_info(
n_steps=10,
inference=False,
),
"network_rps": measure_network_rps(config),
}
try:
throughput_info["network_rps"] = measure_network_rps(config)
except Exception as e:
logger.info(f"Network throughput is not available: {e}")
return throughput_info
def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
process.start()
if not pipe_recv.poll(timeout):
process.terminate()
raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
network_info = pipe_recv.recv()
if "exception" in network_info:
raise RuntimeError(f"speedtest failed: {network_info['exception']}")
def measure_network_rps(
config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6 # 100 Mbit/s
) -> Optional[float]:
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
try:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
process.start()
if not pipe_recv.poll(timeout):
process.terminate()
raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
network_info = pipe_recv.recv()
if "exception" in network_info:
raise RuntimeError(f"speedtest failed: {network_info['exception']}")
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
if network_rps == 0:
raise RuntimeError("speedtest has returned network_rps == 0")
logger.info(
f"Network throughput: {network_rps:.1f} tokens/sec "
f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
)
return network_rps
except RuntimeError as e:
logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s")
return default_speed / bits_per_request
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
if network_rps == 0:
raise RuntimeError("speedtest has returned network_rps == 0")
logger.info(
f"Network throughput: {network_rps:.1f} RPS "
f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
)
return network_rps
def _measure_bits_per_second(pipe_send: mp.Pipe):
@ -198,31 +192,21 @@ def measure_compute_rps(
n_steps: int,
inference: bool,
) -> float:
device = torch.device(device)
if not tensor_parallel_devices:
tensor_parallel_devices = (device,)
with torch.inference_mode():
block = get_model_block(config)
block = block.to(dtype)
block = config.block_class(config).to(dtype)
block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
cache = (DUMMY_KEY_PAST.to(dtype=dtype, device=device), DUMMY_KEY_PAST.to(dtype=dtype, device=device))
cache = None
elapsed = 0
dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
for step in range(n_steps + 1):
dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
# Skip the 1st step to exclude the initialization time
def step(cache_):
outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
return outputs[1] if inference else None
cache = step(cache)
synchronize(device)
start_time = time.perf_counter()
for _ in range(n_steps):
cache = step(cache)
synchronize(device)
elapsed = time.perf_counter() - start_time
start_time = time.perf_counter()
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
if step >= 1: # Skip the 1st step to exclude the initialization time
elapsed += time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed
devices_repr = get_device_name(device)
@ -231,21 +215,14 @@ def measure_compute_rps(
devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
logger.info(
f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block "
f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block "
f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
)
return device_rps
def synchronize(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize(device)
elif device.type == "mps":
torch.mps.synchronize()
def get_device_name(device: torch.device) -> str:
return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()
return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:

@ -4,4 +4,3 @@ from petals.utils.auto_config import (
AutoDistributedModelForCausalLM,
AutoDistributedModelForSequenceClassification,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos

@ -1,14 +1,12 @@
import os
import re
from dataclasses import dataclass
from typing import Optional, Type, Union
from hivemind import get_logger
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from petals.utils.hf_auth import always_needs_auth
logger = get_logger(__name__)
@dataclass
class _ModelClasses:
@ -33,12 +31,8 @@ class _AutoDistributedBase:
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
if (
always_needs_auth(model_name_or_path)
and kwargs.get("token") is None
and kwargs.get("use_auth_token") is None
):
kwargs["use_auth_token"] = True
if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs:
kwargs["token"] = True
config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
if config.model_type not in _CLASS_MAPPING:
@ -51,44 +45,17 @@ class _AutoDistributedBase:
return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
class DefaultRevisionMixin:
"""
Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel).
TII models were recently converted to this format but then reverted back due to compatibility issues.
We chose to support only the new format since HF staff promised to eventually convert these models
to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602
Until it happens, we override the default `main` revision for the TII repos with the commit
pointing out to the model in the in-library format.
"""
DEFAULT_REVISIONS = {
"tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232",
"tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5",
"tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",
"tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28",
}
@classmethod
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs
):
if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS:
revision = cls.DEFAULT_REVISIONS[model_name_or_path]
logger.info(f"Loading {model_name_or_path}, revision {revision}")
return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)
class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase):
class AutoDistributedConfig(_AutoDistributedBase):
_mapping_field = "config"
class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase):
class AutoDistributedModel(_AutoDistributedBase):
_mapping_field = "model"
class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase):
class AutoDistributedModelForCausalLM(_AutoDistributedBase):
_mapping_field = "model_for_causal_lm"
class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
_mapping_field = "model_for_sequence_classification"

@ -1,76 +0,0 @@
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten
def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
"""Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass"""
assert not isinstance(callable, torch.nn.Module)
if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
raise RuntimeError(
"make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
)
flatten_arg, _ = _tree_flatten(sample_args)
flatten_sample_args = tuple(flatten_arg)
assert all(
isinstance(arg, torch.Tensor) for arg in flatten_arg
), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."
len_user_args = len(sample_args)
static_input_surface = flatten_sample_args
graph = torch.cuda.CUDAGraph()
# Warmup
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(num_warmup_iters):
outputs, _ = _tree_flatten(callable(*sample_args))
del outputs
torch.cuda.current_stream().wait_stream(s)
# Capture forward graph
with torch.cuda.graph(graph):
outputs = callable(*sample_args)
flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
static_outputs = tuple(flatten_outputs)
def make_graphed_function(
graph,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
):
def replay_graph(*inputs):
# At this stage, only the user args may (potentially) be new tensors.
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
static_input_surface[i].copy_(inputs[i])
graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
def functionalized(*user_args):
# Runs the autograd function with inputs == all inputs to the graph that might require grad
# (explicit user args + module parameters)
# Assumes module params didn't change since capture.
flatten_user_args, _ = _tree_flatten(user_args)
out = replay_graph(*flatten_user_args)
return _tree_unflatten(out, output_unflatten_spec)
return functionalized
# Put together the final graphed callable
graphed = make_graphed_function(
graph,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
)
return graphed

@ -1,153 +0,0 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import (
CHAIN_DELIMITER,
UID_DELIMITER,
ModuleUID,
RemoteModuleInfo,
RemoteSpanInfo,
ServerInfo,
ServerState,
parse_uid,
)
logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, awaits for declaration to finish, otherwise runs in background
:param throughput: specify your performance in terms of compute throughput
:param expiration_time: declared modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[RemoteModuleInfo], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[RemoteModuleInfo]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
elif expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
for module_info in modules:
metadata = found[module_info.uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
continue
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
module_info.servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
return modules
def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
num_blocks = len(module_infos)
spans = {}
for block_idx, module_info in enumerate(module_infos):
for peer_id, server_info in sorted(module_info.servers.items()):
if server_info.state.value < min_state.value:
continue
if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
spans[peer_id] = RemoteSpanInfo(
peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
)
if server_info.start_block is not None and server_info.end_block is not None:
spans[peer_id].start = max(server_info.start_block - block_offset, 0)
spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
elif spans[peer_id].state == server_info.state:
spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
return spans

@ -22,7 +22,7 @@ def _blocks_lock(cache_dir: Optional[str], mode: int):
lock_path = Path(cache_dir, BLOCKS_LOCK_FILE)
os.makedirs(lock_path.parent, exist_ok=True)
with open(lock_path, "wb+") as lock_fd:
with open(lock_path, "wb") as lock_fd:
fcntl.flock(lock_fd.fileno(), mode)
# The OS will release the lock when lock_fd is closed or the process is killed
yield

@ -0,0 +1,128 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
TokenIds = torch.Tensor
HypoIds = torch.Tensor
class DecodingAlgorithm(ABC):
"""
An abstract class for decoding algorithms. Describes the base function of those algorithms:
they have to select new tokens and provide the corresponding hypotheses.
"""
@abstractmethod
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
:param logits: A tensor of shape (batch_size, seq_length, vocab_size)
:return: A tuple of selected token ids and corresponding hypotheses.
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
"""
pass
class GreedyAlgorithm(DecodingAlgorithm):
"""
The simplest algorithm for decoding. It selects the most probable token.
"""
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
Returns the most probable token. The second returned object is always a range of integers
from 0 to batch_size - 1.
"""
return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
class SamplingAlgorithm(DecodingAlgorithm):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
:param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
:param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
:return: A tuple of selected token ids and corresponding hypotheses.
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
"""
logits[indices_to_remove] = -float("Inf")
probs = torch.softmax(logits / self.temperature, -1)
return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
return self.sample(logits, indices_to_remove)
class TopKAlgorithm(SamplingAlgorithm):
def __init__(self, top_k: int, temperature: float = 1.0) -> None:
super().__init__(temperature=temperature)
self.top_k = top_k
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
return self.sample(logits, indices_to_remove)
class NucleusAlgorithm(SamplingAlgorithm):
def __init__(self, top_p: float, temperature: float = 1.0) -> None:
super().__init__(temperature=temperature)
self.top_p = top_p
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
probs = torch.softmax(sorted_logits / self.temperature, -1)
cumulative_probs = torch.cumsum(probs, dim=-1)
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
return self.sample(logits, indices_to_remove)
class BeamSearchAlgorithm(DecodingAlgorithm):
def __init__(self, num_beams: int, batch_size: int) -> None:
self.num_beams = num_beams
self.batch_size = batch_size
self._batch_beams = [list() for _ in range(batch_size)]
def __call__(self, logits: torch.Tensor):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
probs = torch.log_softmax(sorted_logits, -1)
if len(self._batch_beams[0]) > 0:
for batch_idx in range(self.batch_size):
new_beams = []
cur_beams = self._batch_beams[batch_idx]
for beam_idx in range(len(cur_beams)):
probs_idx = batch_idx + beam_idx * self.batch_size
new_beam = cur_beams[beam_idx]
for hypo_idx in range(self.num_beams):
new_beams.append(
(new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
)
self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
else:
for batch_idx in range(self.batch_size):
for beam_idx in range(self.num_beams):
self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
return_hypos = []
return_tokens = []
for batch_idx in range(self.batch_size):
cur_beam = self._batch_beams[batch_idx]
return_hypos.append(list())
return_tokens.append(list())
for beam in cur_beam:
beam_idx = beam[1] // self.num_beams
hypo_idx = batch_idx + beam_idx * self.batch_size
token_idx = beam[1] % self.num_beams
return_hypos[-1].append(hypo_idx)
return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
return torch.tensor(return_tokens), torch.tensor(return_hypos)

@ -0,0 +1,51 @@
from abc import ABC
import torch
class ABCBloomConstraint(ABC):
"""
Base class of all kind of decoding constraints. It can be used to implement a new constraint.
"""
def __init__(self) -> None:
pass
def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
"""
This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
:param tokens_id: The token id of the last chosen token.
:param logits: The logits from the Bloom model.
:param hypo_ids: The hypothesis ids of the last tokens.
"""
pass
class EosConstraint(ABCBloomConstraint):
"""
This constrained repeats EOS token if it was generated on the previous step.
Args:
prefix: The prefix of the sequence.
eos_token_id: The id of the end of sentence token.
pad_token_id: The id of the padding token.
min_logits: The minimum logits that can be generated. Default: -1e6.
"""
def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
self.eos_token_id = eos_token_id
self.min_logits = min_logits
self.past_tokens = None
self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)
def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
if self.past_tokens is not None:
mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
logits += self.min_logits * mask
logits[mask[:, 0], self.eos_token_id] = 0
if tokens_id is not None:
self.past_tokens = tokens_id
self.wait_until_starting -= 1
return logits

@ -2,28 +2,6 @@ import torch
DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters
DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
DUMMY_KEY_PAST = torch.empty((0, 0, 0))
def is_dummy(tensor: torch.Tensor) -> bool:
def is_dummy(tensor: torch.Tensor):
return tensor.numel() == 0
SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4}
def get_size_in_bytes(dtype: torch.dtype) -> int:
if dtype in SPECIAL_DTYPE_SIZES:
return SPECIAL_DTYPE_SIZES[dtype]
get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
def docstring_from(source):
def add_docstring(dest):
dest.__doc__ = source.__doc__
return dest
return add_docstring

@ -1,49 +0,0 @@
from typing import Any, Dict, List, Tuple
import torch
from hivemind import nested_flatten, nested_pack
# TODO: Move functions to hivemind
def _mark_masked_tensor(index: int) -> bytes:
return b"__T" + str(index).encode()
def _is_masked_tensor(item: Any) -> bool:
return isinstance(item, bytes) and item.startswith(b"__T")
def _get_tensor_index(item: bytes) -> int:
return int(item[3:])
def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
"""
Check the function's arguments and pack all tensors into different flattened lists.
:returns: a flattened list of tensors and args and kwargs, where tensors were masked
"""
masked_flat_values, flat_tensors, tensor_to_index = [], [], {}
for value in nested_flatten((args, kwargs)):
if isinstance(value, torch.Tensor):
tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))
if tensor_index == len(flat_tensors):
flat_tensors.append(value)
masked_flat_values.append(_mark_masked_tensor(tensor_index))
else:
masked_flat_values.append(value)
return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
"""
Restore arguments after `pack_args_kwargs` function.
:returns: list of args and dict of kwargs
"""
return nested_pack(
(
value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]
for value in nested_flatten(args_structure)
),
args_structure,
)

@ -10,23 +10,23 @@ import transformers
from accelerate import init_empty_weights
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.config import PeftConfig
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes
logger = get_logger(__name__)
def check_peft_repository(repo_id: str) -> bool:
return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
fs = HfFileSystem()
list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
return len(list_of_files) > 0
def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
@ -155,15 +155,15 @@ class AdapterContextMixin:
using_adapter = AdapterContextMixin.using_adapter
class LoraLinear(AdapterContextMixin, lora.Linear):
class LoraLinear(lora.Linear, AdapterContextMixin):
"""LoRA linear layer that uses adapter selected via using_adapter"""
class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
"""LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
"""LoRA linear 4-bit that uses adapter selected via using_adapter"""
@ -273,7 +273,7 @@ def estimate_adapter_memory_per_block(
) -> int:
"""Get the number of extra bytes used to store a set of adapters per given block"""
with init_empty_weights(include_buffers=True):
block = get_model_block(block_config)
block = block_config.block_class(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block, quant_type=QuantType.NONE)
@ -284,5 +284,5 @@ def estimate_adapter_memory_per_block(
block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
)
adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
return adapter_parameters * bytes_per_parameter

@ -24,10 +24,7 @@ async def ping(
start_time = time.perf_counter()
await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
return time.perf_counter() - start_time
except Exception as e:
if str(e) == "protocol not supported": # Happens on servers with client-mode DHT (e.g., reachable via relays)
return time.perf_counter() - start_time
except Exception:
logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
return math.inf

@ -0,0 +1,25 @@
import argparse
from datetime import datetime
from huggingface_hub import delete_repo, list_models
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained")
parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
parser.add_argument("--dry_run", action="store_true")
args = parser.parse_args()
for model in list_models(author=args.author, full=True):
last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")
if model.modelId.endswith("-main") or "/test-" not in model.modelId:
continue # remove only test models
if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
if args.dry_run:
print(f"{model.modelId} can be deleted")
else:
delete_repo(repo_id=model.modelId, token=args.use_auth_token)

Binary file not shown.

@ -3,13 +3,10 @@ import sys
import pytest
import torch
from hivemind import nested_compare, nested_flatten
from petals import AutoDistributedConfig
from petals.server.throughput import measure_compute_rps
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
from test_utils import MODEL_NAME
@ -32,9 +29,6 @@ def test_bnb_not_imported_when_unnecessary():
@pytest.mark.parametrize("tensor_parallel", [False, True])
def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
if tensor_parallel and config.model_type != "bloom":
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
compute_rps = measure_compute_rps(
config,
@ -47,29 +41,3 @@ def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: boo
inference=inference,
)
assert isinstance(compute_rps, float) and compute_rps > 0
@pytest.mark.forked
def test_pack_inputs():
x = torch.ones(3)
y = torch.arange(5)
z = DUMMY
args = (x, z, None, (y, y), z)
kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)})
flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
assert len(flat_tensors) == 5
assert all(isinstance(t, torch.Tensor) for t in flat_tensors)
restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
assert len(restored_args) == len(args)
assert torch.all(restored_args[0] == x).item() and restored_args[2] is None
assert nested_compare((args, kwargs), (restored_args, restored_kwargs))
for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):
if isinstance(original, torch.Tensor):
assert torch.all(original == restored)
else:
assert original == restored

@ -3,41 +3,36 @@ import random
import pytest
import torch
from petals import AutoDistributedConfig, RemoteSequential
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals import DistributedBloomConfig, RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@pytest.mark.forked
def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)
block_index = random.randint(0, config.num_hidden_layers - 1)
remote_block = remote_sequential[block_index]
for block_index in random.sample(range(config.num_hidden_layers), 3):
remote_block = remote_sequential[block_index]
inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
outputs_forward = remote_block(inputs)
inputs = torch.randn(1, 8, config.hidden_size)
outputs_forward = remote_block(inputs)
outputs_inference = []
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
# Test long inference (unmerged inference pools)
outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
outputs_inference = []
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
# Test short inference (merged inference pools)
for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
# test that max length is respected
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
# test that max length is respected
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

@ -1,184 +0,0 @@
import asyncio
import multiprocessing as mp
import random
import time
from typing import Optional
import pytest
import pytest_asyncio # make sure the module exists; otherwise the test will be skipped
import torch
from hivemind import TensorDescriptor
from petals.server.memory_cache import AllocationFailed, MemoryCache
from petals.utils.misc import get_size_in_bytes
def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
if dtype is None:
dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
elem_size_bytes = get_size_in_bytes(dtype)
descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
return descr
@pytest.mark.asyncio
async def test_cache_timeout():
cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
cache.runtime_pid += 1 # pretend we're another process
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
pass
async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
t_start = time.perf_counter()
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
pass
assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
pass
t_start = time.perf_counter()
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0): # exceeds max timeout
pass
assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
# test memory allocation when another task frees the memory
async def _klog_the_cache():
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
pass
large_alloc_task = asyncio.create_task(_klog_the_cache())
t_start = time.perf_counter()
await asyncio.sleep(0.05) # wait for large alloc to enqueue
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")): # exceeds max timeout
pass # this memory should allocate once the background task clears the queue
assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
with pytest.raises(AllocationFailed):
await large_alloc_task
# test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
large_alloc_task = asyncio.create_task(_klog_the_cache())
t_start = time.perf_counter()
await asyncio.sleep(0.05) # wait for large alloc to enqueue
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
pass # this memory should allocate once the background task clears the queue
assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
with pytest.raises(AllocationFailed):
await large_alloc_task
@pytest.mark.asyncio
async def test_unlimited_timeout():
cache = MemoryCache(max_size_bytes=1024)
cache.runtime_pid += 1 # pretend we're another process
t_start = time.perf_counter()
async def _klog_the_cache():
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
await asyncio.sleep(0.5)
alloc_task = asyncio.create_task(_klog_the_cache())
await asyncio.sleep(0.1)
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
await alloc_task
assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
@pytest.mark.asyncio
async def test_cache_usage():
cache = MemoryCache(max_size_bytes=2048)
alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
with pytest.raises(AssertionError):
async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
pass # fails because cache must be allocated from another process
descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8)) # 768 bytes
descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64)) # 8 bytes
descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool)) # 33 bytes
descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64)) # 0 bytes
descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16)) # 1536 bytes
descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8)) # 1792 bytes
async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
loop = asyncio.get_event_loop()
async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
pipe_sender.send(handles)
await loop.run_in_executor(None, dealloc_event.wait)
async def _allocate_af():
alloc_event.wait()
allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
await allocate_a_task
allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f)) # klogs the cache
await allocate_f_task
alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
alloc_process1.start()
async def _allocate_bcde():
alloc_event.wait()
await asyncio.sleep(0.1) # ensure that the other tensor is always allocated (and sent through pipe) first
allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit
await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
alloc_process2.start()
assert cache.current_size_bytes == 0
alloc_event.set()
(handle_a,) = pipe_receiver.recv()
handle_b, handle_c, handle_d = pipe_receiver.recv()
with cache.use_cache(handle_a) as (tensor_a,):
assert tensor_a.dtype == torch.uint8
tensor_a[2:5] = torch.tensor((42, 43, 44))
with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
tensor_a += 1
tensor_b[...] = -1.337
assert cache.current_size_bytes == 809 # this checks a,b,c,d are allocated but b still awaits memory
dealloc_bcd_event.set()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 768 # only tensor a should be allocated
with pytest.raises(KeyError):
with cache.use_cache(handle_a, handle_b):
pass # one of handles (c) is deallocated
with pytest.raises(KeyError):
with cache.use_cache(handle_d):
pass # handle_d is deallocated correctly, even though it is never used
with cache.use_cache(handle_a) as (tensor_a,):
assert tuple(tensor_a[2:5]) == (43, 44, 45)
dealloc_a_event.set()
(handle_e,) = pipe_receiver.recv() # e can finally be allocated
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 1536 # tensor e should finally be able to allocate
with pytest.raises(KeyError):
with cache.use_cache(handle_a):
pass # tensor a is no longer allocated
with cache.use_cache(handle_e) as (tensor_e,):
assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
dealloc_e_event.set()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 1792 # only tensor f is still allocated
dealloc_f_event.set()
alloc_process1.join()
alloc_process2.join()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 0
assert cache.current_size_bytes == 0
assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"

@ -7,16 +7,15 @@
import pytest
import torch
from petals import AutoDistributedConfig
from petals import DistributedBloomConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from petals.utils.misc import DUMMY_KEY_PAST
from test_utils import *
@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
assert isinstance(remote_blocks, RemoteSequential)
@ -44,7 +43,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_blocks = RemoteSequential(config, start_block=3, end_block=5)
inputs = torch.randn(1, 8, config.hidden_size)
@ -55,14 +54,12 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
dtype = torch.float32
ref_blocks = [
load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
]
outputs_ref = []
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
caches = [cache, cache]
caches = [None, None]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]

@ -3,44 +3,32 @@ import pytest
import torch
import transformers
from hivemind import get_logger
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM
from petals import AutoDistributedModelForCausalLM
from petals import DistributedBloomForCausalLM
from test_utils import *
logger = get_logger(__name__)
@pytest.fixture
def tokenizer():
# We set use_fast=False since LlamaTokenizerFast is slow on load
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
@pytest.fixture
def model():
return AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
@pytest.fixture
def ref_model():
return transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3):
if use_peft:
model.config.active_adapter = ADAPTER_NAME
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
ref_model.train(False)
def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME,
initial_peers=INITIAL_PEERS,
low_cpu_mem_usage=True,
torch_dtype=torch.float32,
active_adapter=ADAPTER_NAME if use_peft else None,
)
config = model.config
assert isinstance(model, DistributedBloomForCausalLM)
assert len(model.transformer.h) == model.config.num_hidden_layers
test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
with torch.inference_mode():
parallel_outputs = model.forward(test_inputs).logits
@ -52,124 +40,142 @@ def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empt
recurrent_outputs = []
with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
if pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
for t in range(embs.shape[1]):
if t == 4:
recurrent_outputs.append(sess.step(embs[:, 4:9, :]))
elif 4 < t < 9:
continue
else:
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
if t == 2 and pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
if t == int(embs.shape[1] // 2) and pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
recurrent_outputs = model.lm_head(recurrent_outputs)
assert torch.allclose(
recurrent_outputs, parallel_outputs, rtol=0, atol=atol
), "Inference differs from forward pass"
ref_outputs = ref_model.forward(test_inputs).logits.float()
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), "Outputs are not identical to HF"
def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs):
if not multiple_calls:
return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs)
with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
return torch.cat(
[
# Sessions provided both explicitly and implicitly should work
model.generate(inputs, max_new_tokens=1, **kwargs, session=sess),
model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs),
model.generate(None, max_new_tokens=1, **kwargs),
],
dim=1,
)
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")
del model, embs, recurrent_outputs
if REF_NAME:
ref_model = transformers.BloomForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if use_peft:
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
ref_model.train(False)
if config.vocab_size < ref_model.config.vocab_size:
ref_model.resize_token_embeddings(config.vocab_size)
logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
del ref_model, ref_outputs, dummy_mask
else:
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
assert False
@pytest.mark.forked
def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
def test_greedy_generation(max_new_tokens=4):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
)
hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
"input_ids"
]
options = dict(max_new_tokens=max_new_tokens, do_sample=False)
for multiple_calls in [False, True]:
for inputs in [inputs_single, inputs_batch]:
outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(
outputs, ref_outputs
), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
)
hf_outputs_batch = BloomForCausalLM.greedy_search(
model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
)
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Greedy search results are not identical to HF in multibatch mode"
@pytest.mark.forked
def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(sampling_options, max_new_tokens=4):
torch.manual_seed(0)
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
with torch.random.fork_rng():
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs = BloomForCausalLM.sample(
model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
)
assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
"input_ids"
]
for options in [
dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
]:
options.update(max_new_tokens=max_new_tokens)
for multiple_calls in [False, True]:
for inputs in [inputs_single, inputs_batch]:
torch.manual_seed(0)
outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
torch.manual_seed(0)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(
outputs, ref_outputs
), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"
@pytest.mark.skipif(
"bloom" not in MODEL_NAME.lower(),
reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
)
@pytest.mark.forked
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
outputs = make_generate_calls(model, inputs, **options)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
with torch.random.fork_rng():
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
do_sample=True,
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs_batch = BloomForCausalLM.sample(
model,
input_ids=inputs_batch,
max_length=inputs_batch.size(1) + max_new_tokens,
logits_warper=logits_warper,
)
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Sampling results are not identical to HF in multibatch mode"
@pytest.mark.forked
def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
assert inputs.keys() == {"input_ids", "attention_mask"}
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
outputs = torch.cat(
[
model.generate(**inputs, max_new_tokens=2),
model.generate(None, max_new_tokens=max_new_tokens - 2),
],
dim=1,
)
assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"
def test_beam_search_generation(max_new_tokens=4, num_beams=2):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
text = "A cat sat on a mat"
inputs = tokenizer(text, return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
)
beam_scorer = BeamSearchScorer(
batch_size=inputs.size(0),
num_beams=num_beams,
device=inputs.device,
length_penalty=0,
do_early_stopping=False,
)
hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
hf_outputs = BloomForCausalLM.beam_search(
model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
)
assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"

@ -1,224 +0,0 @@
from typing import Optional, Tuple
import pytest
import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from petals.server.block_utils import get_model_block
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
from test_utils import MODEL_NAME
KVCache = Tuple[torch.Tensor, torch.Tensor]
class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[KVCache] = None,
use_cache: bool = False,
**kwargs,
):
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None:
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
past_length = 0 if layer_past is None else layer_past[0].shape[1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None and self.config.alibi:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
outputs = super().forward(
hidden_states,
*args,
attention_mask=attention_mask,
alibi=alibi,
layer_past=layer_past,
use_cache=use_cache,
**kwargs,
)
if use_cache:
present_key_value = outputs[-1]
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
outputs = outputs[:-1] + (present_key_value,)
return outputs
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
if self.config.new_decoder_architecture:
key_states = self._expand_states(key_states)
value_states = self._expand_states(value_states)
return (key_states, value_states)
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value
if self.config.new_decoder_architecture:
key_states = self._collapse_states(key_states)
value_states = self._collapse_states(value_states)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
key_states = key_states.permute(0, 2, 1)
return (key_states, value_states)
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
return state
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
state = state[:, :, 0]
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
return state
class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
batch_size, seq_length, _ = hidden_states.shape
seq_length_with_past = seq_length
past_key_values_length = 0
past_key_value = layer_past
if past_key_value is not None:
past_key_values_length = past_key_value[0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
elif use_cache:
past_key_value = DynamicCache()
if position_ids is None:
device = hidden_states.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
outputs = super().forward(
hidden_states,
*args,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
**kwargs,
)
if use_cache:
present_key_value = outputs[-1]
present_key_value = self._reorder_cache_from_llama_to_bloom(
present_key_value, batch_size, seq_length_with_past
)
outputs = outputs[:-1] + (present_key_value,)
return outputs
def _reorder_cache_from_bloom_to_llama(
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
) -> DynamicCache:
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
key_states = key_states.view(
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
value_states = value_states.view(*key_states.shape)
past_key_values = ((key_states, value_states),)
return DynamicCache.from_legacy_cache(past_key_values)
def _reorder_cache_from_llama_to_bloom(
self, key_value: DynamicCache, batch_size: int, seq_length: int
) -> Tuple[torch.Tensor]:
key_states, value_states = key_value.to_legacy_cache()[0]
value_states = value_states.view(
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
key_states = key_states.view(*value_states.shape)
key_states = key_states.permute(0, 2, 1)
return (key_states, value_states)
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.forked
def test_optimized_block(device):
if device == "cuda:0" and not torch.cuda.is_available():
pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
tensor_parallel_devices = (device,)
dtype = torch.bfloat16
quant_type = QuantType.NONE
block_idx = 1
block = get_model_block(config, layer_idx=block_idx).to(dtype)
block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
if config.model_type == "falcon":
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
elif config.model_type == "llama":
unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
else:
pytest.skip(f"This test is not applicable to {config.model_type} models")
unopt_block = convert_block(
unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
)
unopt_block.load_state_dict(block.state_dict())
cache = unopt_cache = None
with torch.inference_mode():
for length in [10, 1, 1, 1]:
dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length

@ -1,5 +1,4 @@
import multiprocessing as mp
import platform
import time
import pytest
@ -9,30 +8,9 @@ from hivemind.moe.server.runtime import Runtime
from petals.server.task_pool import PrioritizedTaskPool
def _submit_tasks(runtime_ready, pools, results_valid):
runtime_ready.wait()
futures = []
futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
time.sleep(0.01)
futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
for i, f in enumerate(futures):
assert f.result()[0].item() == i**2
results_valid.set()
@pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks")
@pytest.mark.forked
def test_priority_pools():
outputs_queue = mp.SimpleQueue()
runtime_ready = mp.Event()
results_valid = mp.Event()
def dummy_pool_func(x):
@ -53,14 +31,27 @@ def test_priority_pools():
PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
)
# Simulate requests coming from ConnectionHandlers
proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
proc.start()
runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
runtime.ready = runtime_ready
runtime.start()
def process_tasks():
futures = []
futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
time.sleep(0.01)
futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
for i, f in enumerate(futures):
assert f.result()[0].item() == i**2
results_valid.set()
proc = mp.Process(target=process_tasks)
proc.start()
proc.join()
assert results_valid.is_set()
@ -78,5 +69,3 @@ def test_priority_pools():
# 3 - task with priority 2 from pool A
# 4 - task with priority 10 from pool A
# 7 - task with priority 11 from pool B
runtime.shutdown()

@ -4,7 +4,7 @@ import torch.nn.functional as F
from hivemind import DHT, BatchTensorDescriptor, get_logger
from hivemind.proto import runtime_pb2
from petals import AutoDistributedConfig
from petals import DistributedBloomConfig
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.data_structures import UID_DELIMITER
from petals.server.from_pretrained import load_pretrained_block
@ -15,7 +15,7 @@ logger = get_logger(__name__)
@pytest.mark.forked
def test_remote_sequential():
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
grad_proj = torch.randn(1, 5, config.hidden_size)
@ -40,10 +40,10 @@ def test_remote_sequential():
assert hidden.shape == test_inputs.shape
assert hidden.requires_grad
second_half_outputs = second_half(hidden)
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2)
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
@ -56,7 +56,7 @@ def test_remote_sequential():
(approx_outputs * grad_proj).sum().backward()
assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-3), "compression was not used"
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
assert abs(approx_outputs - full_outputs).mean() < 0.01
absmax = abs(full_grad).max()
assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
@ -87,7 +87,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)
inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
@ -126,6 +126,6 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
(outputs_ref * output_proj).sum().backward()
assert input_prompts_ref.grad is not None
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2)
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
assert intermediate_prompts_ref.grad is not None
assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)

@ -5,7 +5,7 @@ import pytest
import torch
from hivemind import DHT, get_logger
from petals import AutoDistributedConfig
from petals import DistributedBloomConfig
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.data_structures import UID_DELIMITER
from test_utils import *
@ -16,7 +16,7 @@ logger = get_logger(__name__)
@pytest.mark.forked
@pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
def test_sequence_manager_basics(mode: str):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
sequential = RemoteSequential(config, dht=dht)
shutdown_evt = threading.Event()

@ -4,16 +4,14 @@ import hivemind
import pytest
import torch
from petals import AutoDistributedConfig, RemoteSequential
from petals import DistributedBloomConfig, RemoteSequential
from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *
@pytest.mark.forked
def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"] # PeerID from server2.id
def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)

@ -14,11 +14,8 @@ from test_utils import MODEL_NAME
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
if model_config.model_type != "bloom":
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
block_index = random.randint(0, 10)
model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
tp_config = None

Loading…
Cancel
Save