Commit 15a715c5ad (pull/493/head) by Danny Boy, 10 months ago

@ -10,10 +10,20 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
include:
- { model: 'bigscience/bloom-560m', python-version: '3.8' }
- { model: 'bigscience/bloom-560m', python-version: '3.9' }
- { model: 'bigscience/bloom-560m', python-version: '3.10' }
- { model: 'bigscience/bloom-560m', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
fail-fast: false
timeout-minutes: 15
steps:
- name: Increase swap space
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
@ -31,44 +41,77 @@ jobs:
pip install .[dev]
- name: Test
run: |
export MODEL_NAME=bigscience/bloom-560m
export REF_NAME=bigscience/bloom-560m
export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
--torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--adapters $ADAPTER_NAME &> server1.log &
SERVER1_PID=$!
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
# [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
bash -c 'while true; do free -h && sleep 30s; done' &
RAM_WATCH_PID=$!
sleep 5 # wait for the first server to initialize DHT
# [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
python -m petals.cli.run_dht \
--identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
BOOTSTRAP_PID=$!
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
# ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
SERVER2_PID=$!
sleep 5 # wait for DHT init
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
--mean_balance_check_period 10 \
--initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
SERVER1_PID=$!
# ^-- rebalancing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
sleep 10 # wait for the 1st server to choose blocks
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
--identity_path tests/server2.id \
--initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
SERVER2_PID=$!
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
--attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
SERVER3_PID=$!
# ^-- chunking test
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
--initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
SERVER4_PID=$!
# ^-- tensor parallelism test (not compatible with adapters yet)
tail -n 100 -f server*.log &
sleep 5 # wait for the log files to appear
tail -n 100 -f bootstrap.log server*.log &
LOGGER_PID=$!
sleep 30 # wait for servers to download layers
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
sleep 30 # wait for servers to eval throughput, download layers, and rebalance
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init
# [Step 3] Run PyTest
pytest tests --durations=0 --durations-min=1.0 -v
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests
# [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3
python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --n_steps 1
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
# [Step 5] Clean up
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests
kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
echo "Done!"

@ -1,19 +1,22 @@
<p align="center">
<img src="https://i.imgur.com/7eR7Pan.png" width="400"><br>
Run large language models at home, BitTorrent-style.<br>
Fine-tuning and inference <a href="https://github.com/bigscience-workshop/petals#benchmarks">up to 10x faster</a> than offloading<br><br>
<a href="https://pypi.org/project/petals/"><img src="https://img.shields.io/pypi/v/petals.svg?color=green"></a><br>
Fine-tuning and inference <a href="https://github.com/bigscience-workshop/petals#benchmarks">up to 10x faster</a> than offloading
<br><br>
<a href="https://pypi.org/project/petals/"><img src="https://img.shields.io/pypi/v/petals.svg?color=green"></a>
<a href="https://discord.gg/tfHfe8B34k"><img src="https://img.shields.io/discord/865254854262652969?label=discord&logo=discord&logoColor=white"></a>
<br>
</p>
Generate text with distributed [LLaMA 2](https://ai.meta.com/llama/) ([70B](https://huggingface.co/meta-llama/Llama-2-70b-hf), [70B-Chat](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM
model_name = "enoch/llama-65b-hf"
model_name = "stabilityai/StableBeluga2"
# You can also use "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-70b-chat-hf",
# "bigscience/bloom", or "bigscience/bloomz"
# repos with LLaMA-65B, "bigscience/bloom", or "bigscience/bloomz"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
@ -30,10 +33,12 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
🦙 **Want to run LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
📋 **Terms of use.** Make sure you follow the model license (see the ones for [LLaMA 2](https://bit.ly/llama2-license), [LLaMA](https://bit.ly/llama-license) and [BLOOM](https://bit.ly/bloom-license)).
📋 **Terms of use.** Make sure you follow the model license (see [LLaMA 2](https://bit.ly/llama2-license), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2/blob/main/LICENSE.txt), [LLaMA](https://bit.ly/llama-license), and [BLOOM](https://bit.ly/bloom-license)).
🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
### Connect your GPU and increase Petals capacity
Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serve one of the models!
@ -43,7 +48,7 @@ Petals is a community-run system &mdash; we rely on people sharing their GPUs. Y
```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
pip install git+https://github.com/bigscience-workshop/petals
python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b
python -m petals.cli.run_server stabilityai/StableBeluga2
```
🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
@ -52,10 +57,10 @@ python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanac
```bash
sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
python -m petals.cli.run_server --port 31330 enoch/llama-65b-hf --adapters timdettmers/guanaco-65b
python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
```
These commands host a part of LLaMA-65B with optional [Guanaco](https://huggingface.co/timdettmers/guanaco-65b) adapters on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then use this command for `petals.cli.run_server`:
@ -63,7 +68,7 @@ These commands host a part of LLaMA-65B with optional [Guanaco](https://huggingf
python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
```
💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!
💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
@ -79,8 +84,8 @@ Basic tutorials:
Useful tools and advanced guides:
- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/borzunov/chat.petals.dev)
- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/borzunov/health.petals.dev)
- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
- Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
- Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
@ -91,8 +96,8 @@ Learning more:
## How does it work?
- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning.
- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
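To illustrate the last point, here is a rough prompt-tuning sketch that treats the distributed model as an ordinary PyTorch module. The `tuning_mode` and `pre_seq_len` arguments are assumptions based on the prompt-tuning notebook and benchmark scripts in this repo, not a verbatim recipe:

```python
# Rough sketch of local prompt tuning on top of the public swarm.
# Assumes from_pretrained accepts tuning_mode / pre_seq_len (PTuneConfig fields
# referenced elsewhere in this diff); treat this as illustrative, not canonical.
import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "stabilityai/StableBeluga2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(
    model_name, tuning_mode="ptune", pre_seq_len=16  # only the soft prompt is trained locally
)
opt = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-3)

batch = tokenizer("A cat sat on a mat", return_tensors="pt")
loss = model(input_ids=batch["input_ids"], labels=batch["input_ids"]).loss  # standard HF-style LM loss
loss.backward()  # gradients flow through the remote servers back to the local prompt
opt.step()
```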
<p align="center">

@ -15,15 +15,15 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=128)
parser.add_argument("--n_steps", type=int, default=100)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()
if args.n_processes == "n_gpus":
@ -31,15 +31,19 @@ def main():
else:
args.n_processes = int(args.n_processes)
processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)]
pipe_recv, pipe_send = mp.Pipe(duplex=False)
processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)]
for proc in processes:
proc.start()
for proc in processes:
proc.join()
speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
logger.info(f"Final result: {speed=:.2f}")
@torch.inference_mode()
def benchmark_forward(process_idx, args):
def benchmark_forward(process_idx, args, result_pipe):
model = AutoDistributedModel.from_pretrained(
args.model,
initial_peers=args.initial_peers,
@ -64,7 +68,7 @@ def benchmark_forward(process_idx, args):
speed = input_ids.numel() / np.mean(step_times)
logger.info(f"{process_idx=} {step=} {speed=:.2f}")
logger.info(f"Final result: {process_idx=} {speed=:.2f}")
result_pipe.send(speed)
if __name__ == "__main__":
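The change above replaces per-process log lines with a single aggregated result: every worker sends its final speed through a shared one-way pipe, and the parent averages them after joining. A standalone sketch of that pattern (the worker body is a stand-in for the actual benchmark):

```python
# Standalone sketch of the result-aggregation pattern used by the benchmark scripts:
# each worker sends one number through the shared pipe, the parent averages them.
import multiprocessing as mp
import numpy as np

def worker(process_idx: int, result_pipe) -> None:
    speed = 10.0 + process_idx  # stand-in for the measured tokens/sec
    result_pipe.send(speed)

if __name__ == "__main__":
    n_processes = 4
    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=worker, args=(i, pipe_send)) for i in range(n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()
    # One recv() per worker; order does not matter since we only average.
    speed = np.mean([pipe_recv.recv() for _ in range(n_processes)])
    print(f"Final result: {speed=:.2f}")
```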

@ -16,13 +16,13 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=2048)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()
if args.n_processes == "n_gpus":
@ -30,15 +30,19 @@ def main():
else:
args.n_processes = int(args.n_processes)
processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)]
pipe_recv, pipe_send = mp.Pipe(duplex=False)
processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]
for proc in processes:
proc.start()
for proc in processes:
proc.join()
speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
logger.info(f"Final result: {speed=:.2f}")
@torch.inference_mode()
def benchmark_inference(process_idx, args):
def benchmark_inference(process_idx, args, result_pipe):
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
# Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway
@ -61,7 +65,7 @@ def benchmark_inference(process_idx, args):
speed = 1 / np.mean(step_times)
logger.info(f"{process_idx=} {step=} {speed=:.2f}")
logger.info(f"Final result: {process_idx=} {speed=:.2f}")
result_pipe.send(speed)
if __name__ == "__main__":

@ -15,18 +15,18 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--task", type=str, default="cls")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=128)
parser.add_argument("--pre_seq_len", type=int, default=16)
parser.add_argument("--n_steps", type=int, default=10)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
parser.add_argument("--task", type=str, default="cls", help="Training task type")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()
assert args.task in ["cls", "causal_lm"]
@ -36,14 +36,18 @@ def main():
else:
args.n_processes = int(args.n_processes)
processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)]
pipe_recv, pipe_send = mp.Pipe(duplex=False)
processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)]
for proc in processes:
proc.start()
for proc in processes:
proc.join()
fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0)
logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}")
def benchmark_training(process_idx, args):
def benchmark_training(process_idx, args, result_pipe):
if args.task == "cls":
model = AutoDistributedModelForSequenceClassification.from_pretrained(
args.model,
@ -96,7 +100,7 @@ def benchmark_training(process_idx, args):
bwd_speed = input_ids.numel() / np.mean(bwd_times)
logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
result_pipe.send((fwd_speed, bwd_speed))
if __name__ == "__main__":

@ -92,9 +92,6 @@
},
"outputs": [],
"source": [
"# Choose a model you'd like to prompt-tune. We recommend starting with\n",
"# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n",
"# The code below uses LLaMA-65B.\n",
"MODEL_NAME = \"enoch/llama-65b-hf\"\n",
"\n",
"# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",

@ -32,7 +32,7 @@ packages = find:
python_requires = >=3.8
install_requires =
torch>=1.12
bitsandbytes==0.40.1.post1
bitsandbytes==0.41.1
accelerate>=0.20.3,<0.21.0
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3

@ -11,7 +11,7 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "2.0.1"
__version__ = "2.0.1.post2"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

@ -7,8 +7,8 @@ This script may be used for launching lightweight CPU machines serving as bootst
This may be eventually merged to the hivemind upstream.
"""
import argparse
import time
from argparse import ArgumentParser
from secrets import token_hex
from hivemind.dht import DHT, DHTNode
@ -35,7 +35,7 @@ async def report_status(dht: DHT, node: DHTNode):
def main():
parser = ArgumentParser()
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--initial_peers",
nargs="*",
@ -73,7 +73,9 @@ def main():
help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
)
parser.add_argument(
"--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
"--use_auto_relay",
action="store_true",
help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
)
parser.add_argument(
"--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"

@ -122,7 +122,7 @@ def main():
help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")
group = parser.add_mutually_exclusive_group()
group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
group.add_argument('--initial_peers', type=str, nargs='+', required=False, default=PUBLIC_INITIAL_PEERS,
help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
group.add_argument('--new_swarm', action='store_true',
help='Start a new private swarm (i.e., do not connect to any initial peers)')
@ -158,7 +158,7 @@ def main():
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")
parser.add_argument("--adapters", nargs='+', default=(),
parser.add_argument("--adapters", nargs='*', default=(),
help="List of pre-loaded LoRA adapters that can be used for inference or training")
# fmt:on

@ -1,4 +1,4 @@
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase

@ -0,0 +1,31 @@
import dataclasses
from typing import Optional, Sequence, Union
from hivemind import PeerID
from petals.constants import PUBLIC_INITIAL_PEERS
@dataclasses.dataclass
class ClientConfig:
initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, send requests only to these servers
blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, do not use these servers
use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
max_retries: Optional[int] = None # max number of retries before the client raises an exception (default: inf)
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 3 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
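A hedged usage sketch for the new `blocked_servers` / `allowed_servers` fields; the peer IDs are placeholders, and `petals.ClientConfig` is the import path named in the deprecation warning further down this diff:

```python
# Sketch: a client config that avoids two specific servers and caps retries.
# Peer IDs below are placeholders for real base58-encoded libp2p peer IDs.
from petals import ClientConfig

config = ClientConfig(
    blocked_servers=["<peer_id_1>", "<peer_id_2>"],  # placeholder peer IDs
    max_retries=3,       # raise after 3 failed attempts instead of retrying forever
    request_timeout=60,  # seconds per forward/backward/inference request
)
```

`RemoteSequenceManager._peer_ids_to_set` (added below) converts these strings into `PeerID` objects, and `_update` then filters `block_info.servers` against both lists.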

@ -7,22 +7,18 @@ import uuid
from typing import AsyncIterator, List, Optional, Tuple
import torch
from hivemind import (
MSGPackSerializer,
anext,
deserialize_torch_tensor,
get_logger,
nested_flatten,
serialize_torch_tensor,
)
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -36,7 +32,7 @@ class _ServerInferenceSession:
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
@ -63,7 +59,7 @@ class _ServerInferenceSession:
@classmethod
async def create(
cls,
config: SequenceManagerConfig,
config: ClientConfig,
p2p: P2P,
span: RemoteSpanInfo,
uid: ModuleUID,
@ -75,7 +71,7 @@ class _ServerInferenceSession:
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
config.request_timeout,
config.connect_timeout,
)
return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
@ -128,13 +124,13 @@ class _ServerInferenceSession:
assert prompts.shape[3] == inputs.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
input_tensors = (inputs, prompts, hypo_ids)
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
@ -144,13 +140,25 @@ class _ServerInferenceSession:
if next_servers:
request_metadata["next_servers"] = next_servers
request_metadata["args_structure"] = args_structure
# TODO: make it possible to use different compression methods for different tensors
server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
compression = server_side_inference_schema[0].compression
inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
# TODO: create a more explicit way to check the server's schema against the client's structure
assert len(input_tensors) >= len(
server_side_inference_schema
), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
for tensor, proto in zip(input_tensors, inference_schema)
],
metadata=MSGPackSerializer.dumps(request_metadata),
)
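For reference, the pack/unpack helpers used here and in the server-side `run_rpc_forward` later in this diff form a simple round trip. A sketch of the assumed contract, inferred from the call sites in this diff (flattened tensors travel as serialized payload, `args_structure` travels in the request metadata):

```python
# Sketch of the assumed pack/unpack contract, inferred from call sites in this diff.
import torch
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

hidden_states = torch.zeros(1, 4, 8)
prompts = torch.zeros(2, 1, 2, 8)
hypo_ids = torch.arange(1)

flat_tensors, args_structure = pack_args_kwargs(hidden_states, prompts, hypo_ids)
# ... flat_tensors are serialized; args_structure is sent alongside in metadata ...
restored, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states2, prompts2, hypo_ids2 = restored  # same tensors, same order
```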

@ -12,53 +12,55 @@ from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_U
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.config import ClientConfig
from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
timeout=timeout,
timeout=config.request_timeout,
)
return [deserialize_torch_tensor(t) for t in outputs.tensors]
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
timeout=timeout,
timeout=config.request_timeout,
)
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
outputs = aiter_with_timeout(outputs, timeout)
outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout)
outputs = aiter_with_timeout(outputs, config.request_timeout)
return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
grad_inputs = aiter_with_timeout(grad_inputs, timeout)
grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout)
grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout)
return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
@ -67,7 +69,7 @@ async def run_remote_forward(
stub: StubBase,
rpc_info: RPCInfo,
*inputs: torch.Tensor,
timeout: float,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
) -> Tuple[torch.Tensor, ...]:
@ -83,26 +85,20 @@ async def run_remote_forward(
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)
# Modify forward_schema to support prompts
forward_inputs = tuple(nested_flatten((inputs, kwargs)))
args_schema, kwargs_schema = rpc_info["forward_schema"]
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
forward_inputs = nested_flatten(forward_inputs)
compression = args_schema[0].compression
forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# TODO: create a more explicit way to check the server's schema against the client's structure
assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
for tensor, proto in zip(inputs, forward_schema)
)
)
@ -110,7 +106,7 @@ async def run_remote_forward(
size = sum(t.element_size() * t.nelement() for t in inputs)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
@ -118,10 +114,8 @@ async def run_remote_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
timeout: float,
*inputs_and_grad_outputs: torch.Tensor,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
) -> Sequence[torch.Tensor]:
@ -130,16 +124,14 @@ async def run_remote_backward(
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
outputs_schema = rpc_info["outputs_schema"]
compression = args_schema[0].compression
backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)
# TODO: create a more explicit way to check the server's schema against the client's structure
assert (
len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1
), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"
# Asynchronous serialization
loop = asyncio.get_running_loop()
@ -153,5 +145,5 @@ async def run_remote_backward(
size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
return deserialized_grad_inputs

@ -6,8 +6,9 @@ import torch
from hivemind import DHT, get_logger
from torch import nn
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
from petals.client.routing import RemoteSequenceManager
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
from petals.utils.misc import DUMMY
@ -22,7 +23,7 @@ class RemoteSequential(nn.Module):
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
*,
sequence_manager: Optional[RemoteSequenceManager] = None,
dht: Optional[DHT] = None,

@ -1 +1,2 @@
"""Client-side functions responsible for choosing the best server, """
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@ -7,7 +7,8 @@ import logging
import random
import threading
import time
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from weakref import WeakMethod
import dijkstar
@ -18,39 +19,27 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
import petals.dht_utils
from petals.client.config import ClientConfig
from petals.client.routing.sequence_info import RemoteSequenceInfo
from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
from petals.utils.dht import get_remote_module_infos
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
logger = get_logger(__name__)
@dataclasses.dataclass
class SequenceManagerConfig:
initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
use_server_to_server: bool = True # Use direct server-to-server communication
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf)
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 5 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
class SequenceManagerConfig(ClientConfig):
def __init__(self, *args, **kwargs):
warnings.warn(
"petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. "
"This alias will be removed in Petals 2.2.0+",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
@dataclasses.dataclass
@ -81,7 +70,7 @@ class RemoteSequenceManager:
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
block_uids: Sequence[ModuleUID],
*,
dht: Optional[DHT] = None,
@ -115,6 +104,9 @@ class RemoteSequenceManager:
self._thread_start_lock = threading.Lock()
self.policy = NoSpendingPolicy()
self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
self.ping_aggregator = PingAggregator(dht)
if state.banned_peers is None:
@ -127,6 +119,23 @@ class RemoteSequenceManager:
self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
if peer_ids is None:
return None
result = set()
for peer_id in peer_ids:
if isinstance(peer_id, PeerID):
result.add(peer_id)
elif isinstance(peer_id, str):
result.add(PeerID.from_base58(peer_id))
else:
raise TypeError(
f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
)
return result
def make_sequence(
self,
start_index: int = 0,
@ -291,9 +300,9 @@ class RemoteSequenceManager:
# This is okay since false positives are more costly than false negatives here.
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
def _make_sequence_with_max_throughput(
self, start_index: int, end_index: int, *, relay_penalty: float = 0.5
) -> List[RemoteSpanInfo]:
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
client_server_rtts = self.ping_aggregator.to_dict()
span_sequence = []
current_index = start_index
while current_index < end_index:
@ -301,11 +310,11 @@ class RemoteSequenceManager:
if not candidate_spans:
raise MissingBlocksError(current_index)
# We choose longer servers to minimize the number of hops but leave some randomization
# to distribute the load. We also exclude servers known to be unreachable.
eps = 1e-6
span_weights = np.array(
[
span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty)
for span in candidate_spans
],
[span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
dtype=np.float64,
)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
@ -332,7 +341,7 @@ class RemoteSequenceManager:
def _update(self):
"""Perform an immediate and synchronous refresh, may take time"""
new_block_infos = petals.dht_utils.get_remote_module_infos(
new_block_infos = get_remote_module_infos(
self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
)
@ -340,13 +349,13 @@ class RemoteSequenceManager:
if not block_info:
continue
# Apply whitelist, if defined
if self.config.allowed_servers is not None:
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
}
# Apply allow and block lists
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if (self.allowed_servers is None or peer_id in self.allowed_servers)
and (self.blocked_servers is None or peer_id not in self.blocked_servers)
}
# Remove temporarily banned peers, unless there are no peers left
valid_servers = {
@ -368,9 +377,13 @@ class RemoteSequenceManager:
self.state.sequence_info.update_(new_block_infos)
first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
middle_servers = [
span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
]
last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
pinged_servers |= set(sample_up_to(middle_servers, self.config.max_pinged))
pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
@ -461,14 +474,21 @@ class RemoteSequenceManager:
return 0
return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
def get_request_metadata(
self, protocol: str, args_structure: Any = None, *args, **kwargs
) -> Optional[Dict[str, Any]]:
"""
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
:param args: request-specific inputs, typically block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
return dict(
points=self.policy.get_points(protocol, *args, **kwargs),
active_adapter=self.config.active_adapter,
args_structure=args_structure,
)
def shutdown(self):
self._thread.shutdown()

@ -12,10 +12,11 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -67,16 +68,18 @@ async def sequential_forward(
span = sequences.popleft()
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
metadata = sequence_manager.get_request_metadata(
"rpc_forward", args_structure, span_uids, *flat_tensors
)
(outputs,) = await run_remote_forward(
span_uids,
stub,
sequence_manager.rpc_info,
*inputs_and_prompts,
timeout=sequence_manager.config.request_timeout,
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
)
@ -149,19 +152,22 @@ async def sequential_backward(
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
flat_tensors, args_structure = pack_args_kwargs(
inputs, *grad_outputs_cpu, prompts[span.start : span.end]
)
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
metadata = sequence_manager.get_request_metadata(
"rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
"rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
)
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids,
stub,
sequence_manager.rpc_info,
inputs,
grad_outputs,
prompts[span.start : span.end],
timeout=sequence_manager.config.request_timeout,
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
)
grad_outputs = [grad_outputs]

@ -6,8 +6,6 @@ import pydantic
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
from petals.server.memory_cache import Handle
ModuleUID = str
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@ -78,6 +76,8 @@ class RemoteSpanInfo:
RPCInfo = Dict[str, Any]
Handle = int
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:

@ -1,124 +1,9 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
import warnings
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
warnings.warn(
"petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
DeprecationWarning,
stacklevel=2,
)
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, awaits for declaration to finish, otherwise runs in background
:param throughput: specify your performance in terms of compute throughput
:param expiration_time: declared modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
elif expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules
from petals.utils.dht import *

@ -5,15 +5,15 @@ from hivemind import get_logger
from transformers.models.bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.bloom.block import WrappedBloomBlock
logger = get_logger(__name__)
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedBloomBlock
attn_class = BloomAttention
block_prefix = "h"

@ -5,15 +5,15 @@ from hivemind import get_logger
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.llama.block import WrappedLlamaBlock
logger = get_logger(__name__)
class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedLlamaBlock
attn_class = LlamaAttention
block_prefix = "model.layers"

@ -0,0 +1,230 @@
"""
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
"""
from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
from hivemind.utils.nested import nested_flatten
from petals.data_structures import Handle, InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import unpack_args_kwargs
# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1
logger = get_logger(__name__)
async def run_rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
if args_structure is not None:
# TODO: kwargs is currently unused; it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, *_ = flat_tensors
dtype = requested_backends[0].dtype
# Check and parse input tensors, then cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "Petals supports only prioritized pools"
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
async def run_rpc_backward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
if args_structure is not None:
# TODO: kwargs is currently unused; it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
inputs, grad_outputs, prompts, *_ = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not run the forward pass for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "Petals supports only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "Petals supports only prioritized pools"
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
async def iterate_rpc_inference(
requested_uids: Sequence[ExpertUID],
requested_backends: Sequence[TransformerBackend],
active_adapter: Optional[str],
input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
cache_handles: Sequence[Sequence[Handle]],
*,
max_length: int,
prioritizer: TaskPrioritizerBase,
points: int,
quant_type: QuantType,
args_structure: Any = None,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
assert len(cache_handles) == len(requested_backends)
prefix_length = 0
point_per_piece = points / max_length if max_length > 0 else 0.0
async for request, step_metadata in input_iterator:
flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
if args_structure is not None:
# TODO: kwargs is currently unused; it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, hypo_ids, *_ = flat_tensors
batch_size, length_increment, _ = hidden_states.shape
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
can_merge_pools = batch_size * length_increment <= merge_max_tokens
priority = prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="short_inference" if can_merge_pools else "inference",
)
# A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.,
# when the user wants to pre-allocate the cache or check that the server *can* allocate that cache.
if hidden_states.numel() > 0:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
if can_merge_pools:
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
else:
for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
(hidden_states,) = await backend.inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, prompt, priority=priority
)
# serialize and send last layer outputs
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
can_push = not has_prompts
yield output_tensors, can_push
# prepare for next step
prefix_length += length_increment
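
For reference, the merged-pool decision in iterate_rpc_inference above reduces to a single token-count threshold that depends on quantization. Below is a standalone restatement (a minimal sketch, not part of the diff; the is_nf4 flag stands in for the quant_type == QuantType.NF4 check):

# Standalone restatement of the merged-pool rule from block_functions.py above.
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1  # kept at 1 token until bitsandbytes ships a parallel NF4 kernel

def can_merge_inference_pools(batch_size: int, length_increment: int, is_nf4: bool) -> bool:
    merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if is_nf4 else MAX_SHORT_INFERENCE_TOKENS
    return batch_size * length_increment <= merge_max_tokens

assert can_merge_inference_pools(1, 1, is_nf4=True)         # single-token decoding steps always merge
assert can_merge_inference_pools(1, 128, is_nf4=False)      # a 128-token request still fits the threshold
assert not can_merge_inference_pools(2, 100, is_nf4=False)  # 200 tokens go through per-block pools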

@ -11,7 +11,8 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]
"""If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
if dtype not in ("auto", None):
return dtype
if config.torch_dtype not in ("auto", None):
if config.torch_dtype not in ("auto", None, torch.float32):
# If config specifies float32, we override it to the default dtype below
return config.torch_dtype
return torch.bfloat16
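
The hunk above changes the dtype fallback so that a float32 specified in the model config no longer wins over the server default. A quick standalone restatement of the rule, using SimpleNamespace as a stand-in config (a sketch, not part of the diff):

from types import SimpleNamespace
from typing import Union

import torch

def resolve_block_dtype_sketch(config, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
    """Mirror of the rule above: an explicit dtype wins; a config dtype wins unless auto/None/float32; else bfloat16."""
    if dtype not in ("auto", None):
        return dtype
    if config.torch_dtype not in ("auto", None, torch.float32):
        return config.torch_dtype
    return torch.bfloat16

assert resolve_block_dtype_sketch(SimpleNamespace(torch_dtype=torch.float32), "auto") == torch.bfloat16
assert resolve_block_dtype_sketch(SimpleNamespace(torch_dtype=torch.float16), "auto") == torch.float16
assert resolve_block_dtype_sketch(SimpleNamespace(torch_dtype=torch.float32), torch.float32) == torch.float32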

@ -6,7 +6,7 @@ import multiprocessing as mp
import sys
from enum import Enum
from itertools import chain
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple
import torch
from async_timeout import timeout
@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, Handle, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.convert_block import QuantType
logger = get_logger(__name__)
@ -72,6 +71,7 @@ class TransformerConnectionHandler(ConnectionHandler):
session_timeout: float,
step_timeout: float,
task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
quant_type: QuantType,
):
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
@ -89,6 +89,7 @@ class TransformerConnectionHandler(ConnectionHandler):
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self._prioritizer = task_prioritizer
self.quant_type = quant_type
async def add_p2p_handlers(self, *args, **kwargs) -> None:
if self._listener_task is None:
@ -147,9 +148,9 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
args_structure = metadata.get("args_structure")
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(
@ -163,78 +164,30 @@ class TransformerConnectionHandler(ConnectionHandler):
f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
)
point_per_piece = points / max_length if max_length > 0 else 0.0
batch_size = request.tensors[0].size[0] if request.tensors else 1
prefix_length = 0
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
first_request = request
background_tasks = set()
async for request, metadata in self._iterate_inference_steps(
first_request, requests, session_id, requested_uids, context
async for output_tensors, can_push in iterate_rpc_inference(
requested_uids=requested_uids,
requested_backends=requested_backends,
active_adapter=self._get_active_adapter(metadata),
input_iterator=self._iterate_inference_steps(
request, requests, session_id, requested_uids, context
),
cache_handles=cache_handles,
max_length=max_length,
prioritizer=self._prioritizer,
points=points,
quant_type=self.quant_type,
args_structure=args_structure,
):
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
priority = self._prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
# serialize and send last layer outputs
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
if not has_prompts:
if can_push:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
background_tasks.add(task) # Keep reference until it is done to save it from GC
task.add_done_callback(background_tasks.discard)
yield runtime_pb2.ExpertResponse(tensors=output_tensors)
# prepare for next step
prefix_length += length_increment
finally:
self._log_request("rpc_inference.close", requested_uids, context)
@ -404,16 +357,18 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
hidden_states = await run_rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@ -431,16 +386,18 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
hidden_states = await run_rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
# Split the serialized_output for streaming and respond to client
@ -482,16 +439,18 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grads = await _rpc_backward(
grads = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@ -507,16 +466,18 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await _rpc_backward(
grads = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
# Split the serialized_grad_inputs for streaming and respond
for tensor in self._serialize_grads(grads, requested_backends, metadata):
@ -621,105 +582,3 @@ class TransformerConnectionHandler(ConnectionHandler):
result.update(block_info)
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
async def _rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, prompts = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
async def _rpc_backward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape

@ -16,12 +16,11 @@ import hivemind
import torch
from hivemind.utils import TensorDescriptor, get_logger
from petals.data_structures import Handle
from petals.utils.asyncio import shield_and_wait
logger = get_logger(__name__)
Handle = int
class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

@ -20,7 +20,6 @@ from transformers import PretrainedConfig
import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size, resolve_block_dtype
@ -31,6 +30,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.dht import declare_active_modules, get_remote_module_infos
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
from petals.utils.version import get_compatible_model_repo
@ -78,12 +78,12 @@ class Server:
sender_threads: int = 1,
balance_quality: float = 0.75,
mean_balance_check_period: float = 120,
mean_block_selection_delay: float = 2.5,
mean_block_selection_delay: float = 5,
token: Optional[Union[str, bool]] = None,
quant_type: Optional[QuantType] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
skip_reachability_check: bool = False,
dht_client_mode: Optional[bool] = None,
reachable_via_relay: Optional[bool] = None,
use_relay: bool = True,
use_auto_relay: bool = True,
adapters: Sequence[str] = (),
@ -129,20 +129,20 @@ class Server:
for block_index in range(self.block_config.num_hidden_layers)
]
if dht_client_mode is None:
if reachable_via_relay is None:
is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer
logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}")
reachable_via_relay = is_reachable is False # if can't check reachability (returns None), run a full peer
logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")
self.dht = DHT(
initial_peers=initial_peers,
start=True,
num_workers=self.block_config.num_hidden_layers,
use_relay=use_relay,
use_auto_relay=use_auto_relay,
client_mode=dht_client_mode,
client_mode=reachable_via_relay,
**kwargs,
)
self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
@ -201,6 +201,8 @@ class Server:
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
num_blocks = self._choose_num_blocks()
if num_blocks is not None:
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
if block_indices is not None:
try:
first_block_index, last_block_index = block_indices.split(":")
@ -227,6 +229,7 @@ class Server:
num_blocks=num_blocks,
quant_type=quant_type,
tensor_parallel_devices=self.tensor_parallel_devices,
reachable_via_relay=reachable_via_relay,
force_eval=(throughput == "eval"),
cache_dir=cache_dir,
)
@ -239,7 +242,7 @@ class Server:
adapters=tuple(adapters),
torch_dtype=str(torch_dtype).replace("torch.", ""),
quant_type=quant_type.name.lower(),
using_relay=self.dht.client_mode,
using_relay=reachable_via_relay,
**throughput_info,
)
@ -294,7 +297,7 @@ class Server:
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
logger.info(
f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
f"Server will fill your GPU memory with {num_blocks} transformer blocks. "
f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
)
return num_blocks
@ -557,6 +560,7 @@ class ModuleContainer(threading.Thread):
request_timeout=request_timeout,
session_timeout=session_timeout,
step_timeout=step_timeout,
quant_type=QuantType[server_info.quant_type.upper()],
)
for i in range(num_handlers)
]
@ -697,7 +701,9 @@ class ModuleAnnouncerThread(threading.Thread):
delay = self.update_period - (time.perf_counter() - start_time)
if delay < 0:
logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it")
logger.warning(
f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
)
self.trigger.wait(max(delay, 0))
self.trigger.clear()

@ -13,9 +13,10 @@ class TaskPrioritizerBase(ABC):
class DummyTaskPrioritizer(TaskPrioritizerBase):
"""Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
# Inference steps (especially short ones) go first since they are more latency-sensitive
if kwargs.get("type") == "short_inference":
return 1.0
if kwargs.get("type") == "inference":
return 1.0 # inference steps go first since they are more latency-sensitive
return 2.0 # forward, backward
return 2.0
return 3.0 # Forward, backward
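
The updated prioritizer introduces a third level: short inference steps get the smallest value, regular inference comes next, and forward/backward batches go last (assuming, as the comments above imply, that PrioritizedTaskPool schedules smaller priority values first). A minimal usage sketch:

from petals.server.task_prioritizer import DummyTaskPrioritizer

prioritizer = DummyTaskPrioritizer()
p_short = prioritizer.prioritize(type="short_inference")  # 1.0, served first
p_infer = prioritizer.prioritize(type="inference")        # 2.0
p_batch = prioritizer.prioritize(type="forward")          # 3.0, forward/backward go last
assert p_short < p_infer < p_batch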

@ -41,6 +41,8 @@ def get_server_throughput(
num_blocks: int,
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
reachable_via_relay: bool,
relay_penalty: float = 0.2,
force_eval: bool = False,
cache_dir: Optional[str] = None,
) -> Dict[str, float]:
@ -49,7 +51,7 @@ def get_server_throughput(
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
lock_path = Path(cache_dir, "throughput.lock")
cache_path = Path(cache_dir, "throughput_v4.json")
cache_path = Path(cache_dir, "throughput_v5.json")
# We use the system-wide lock since only one process at a time can measure the host throughput
os.makedirs(lock_path.parent, exist_ok=True)
@ -94,7 +96,10 @@ def get_server_throughput(
# E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
average_blocks_used = (num_blocks + 1) / 2
throughput = throughput_info["forward_rps"] / average_blocks_used
throughput = min(throughput, throughput_info.get("network_rps", math.inf))
network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1)
throughput = min(throughput, network_rps)
throughput_info["throughput"] = throughput
logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")
@ -191,6 +196,7 @@ def measure_compute_rps(
n_steps: int,
inference: bool,
) -> float:
device = torch.device(device)
if not tensor_parallel_devices:
tensor_parallel_devices = (device,)
with torch.inference_mode():
@ -199,13 +205,17 @@ def measure_compute_rps(
cache = None
elapsed = 0
for step in range(n_steps + 1):
dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
_, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time
if device.type == "cuda":
torch.cuda.synchronize(device)
start_time = time.perf_counter()
start_time = time.perf_counter()
for step in range(n_steps):
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
if step >= 1: # Skip the 1st step to exclude the initialization time
elapsed += time.perf_counter() - start_time
if device.type == "cuda":
torch.cuda.synchronize(device)
elapsed = time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed
devices_repr = get_device_name(device)
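
With the change above, a server that is only reachable via relays reports a penalized network throughput, so its advertised speed is capped more aggressively. A rough sketch of the combined rule (the numbers are invented; only the 0.2 relay penalty and the averaging over blocks come from the diff):

import math

def report_throughput(forward_rps: float, network_rps: float, num_blocks: int,
                      reachable_via_relay: bool, relay_penalty: float = 0.2) -> float:
    average_blocks_used = (num_blocks + 1) / 2  # E[Uniform{1, ..., num_blocks}]
    compute_bound = forward_rps / average_blocks_used
    network_bound = network_rps * (relay_penalty if reachable_via_relay else 1)
    return min(compute_bound, network_bound)

# A directly reachable server is limited by compute; the same server behind a relay is network-bound:
assert report_throughput(1000, 300, num_blocks=9, reachable_via_relay=False) == 200.0
assert math.isclose(report_throughput(1000, 300, num_blocks=9, reachable_via_relay=True), 60.0)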

@ -4,3 +4,4 @@ from petals.utils.auto_config import (
AutoDistributedModelForCausalLM,
AutoDistributedModelForSequenceClassification,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos

@ -0,0 +1,124 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, waits for the declaration to finish, otherwise runs in the background
:param server_info: a ServerInfo describing the state and performance of the server that hosts these modules
:param expiration_time: declared modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
elif expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules
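
A minimal usage sketch for the helpers above (not from the diff): servers periodically declare the block UIDs they host, and clients fetch the freshest records. The DHT below is a fresh private one and the UIDs are hypothetical; only the function signatures come from petals/utils/dht.py.

from hivemind import DHT

from petals.data_structures import UID_DELIMITER
from petals.utils.dht import get_remote_module_infos

dht = DHT(start=True)  # in practice, pass initial_peers=... to join an existing swarm
uids = [f"bigscience/bloom-560m-petals{UID_DELIMITER}{i}" for i in range(4)]  # hypothetical block UIDs

# Servers would call declare_active_modules(dht, uids, server_info, expiration_time=...) in a loop;
# a client then queries the same keys. Entries are None for blocks that nobody currently serves.
infos = get_remote_module_infos(dht, uids, latest=True)
for uid, info in zip(uids, infos):
    peers = sorted(str(peer_id) for peer_id in info.servers) if info is not None else []
    print(uid, "->", peers or "no servers")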

@ -2,6 +2,8 @@ import torch
DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters
DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
def is_dummy(tensor: torch.Tensor):
return tensor.numel() == 0

@ -0,0 +1,49 @@
from typing import Any, Dict, List, Tuple
import torch
from hivemind import nested_flatten, nested_pack
# TODO: Move functions to hivemind
def _mark_masked_tensor(index: int) -> bytes:
return b"__T" + str(index).encode()
def _is_masked_tensor(item: Any) -> bool:
return isinstance(item, bytes) and item.startswith(b"__T")
def _get_tensor_index(item: bytes) -> int:
return int(item[3:])
def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
"""
Inspect the given args and kwargs and extract all tensors into a single flat list.
:returns: a flat list of tensors and the original (args, kwargs) structure with each tensor replaced by an index mask
"""
masked_flat_values, flat_tensors, tensor_to_index = [], [], {}
for value in nested_flatten((args, kwargs)):
if isinstance(value, torch.Tensor):
tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))
if tensor_index == len(flat_tensors):
flat_tensors.append(value)
masked_flat_values.append(_mark_masked_tensor(tensor_index))
else:
masked_flat_values.append(value)
return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
"""
Restore the original arguments packed by `pack_args_kwargs`.
:returns: the original (args, kwargs) with tensors restored in place of their index masks
"""
return nested_pack(
(
value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]
for value in nested_flatten(args_structure)
),
args_structure,
)
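
A short usage sketch for the packing helpers above: tensors are pulled out into one flat list, everything else stays in the returned structure, and unpacking restores the original nesting (the sample arguments are made up):

import torch

from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

x, y = torch.ones(2, 3), torch.zeros(4)
flat_tensors, structure = pack_args_kwargs(x, scale=0.5, extras=(y, "label"))

assert len(flat_tensors) == 2 and flat_tensors[0] is x and flat_tensors[1] is y  # only tensors are extracted
args, kwargs = unpack_args_kwargs(flat_tensors, structure)
assert torch.equal(args[0], x) and kwargs["scale"] == 0.5 and torch.equal(kwargs["extras"][0], y)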

@ -24,7 +24,10 @@ async def ping(
start_time = time.perf_counter()
await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
return time.perf_counter() - start_time
except Exception:
except Exception as e:
if str(e) == "protocol not supported": # Happens on servers with client-mode DHT (e.g., reachable via relays)
return time.perf_counter() - start_time
logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
return math.inf

Binary file not shown.

@ -3,10 +3,13 @@ import sys
import pytest
import torch
from hivemind import nested_compare, nested_flatten
from petals import AutoDistributedConfig
from petals.server.throughput import measure_compute_rps
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
from test_utils import MODEL_NAME
@ -29,6 +32,9 @@ def test_bnb_not_imported_when_unnecessary():
@pytest.mark.parametrize("tensor_parallel", [False, True])
def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
if tensor_parallel and config.model_type != "bloom":
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
compute_rps = measure_compute_rps(
config,
@ -41,3 +47,29 @@ def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: boo
inference=inference,
)
assert isinstance(compute_rps, float) and compute_rps > 0
@pytest.mark.forked
def test_pack_inputs():
x = torch.ones(3)
y = torch.arange(5)
z = DUMMY
args = (x, z, None, (y, y), z)
kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)})
flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
assert len(flat_tensors) == 5
assert all(isinstance(t, torch.Tensor) for t in flat_tensors)
restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
assert len(restored_args) == len(args)
assert torch.all(restored_args[0] == x).item() and restored_args[2] is None
assert nested_compare((args, kwargs), (restored_args, restored_kwargs))
for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):
if isinstance(original, torch.Tensor):
assert torch.all(original == restored)
else:
assert original == restored

@ -3,36 +3,41 @@ import random
import pytest
import torch
from petals import DistributedBloomConfig, RemoteSequential
from petals import AutoDistributedConfig, RemoteSequential
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@pytest.mark.forked
def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)
for block_index in random.sample(range(config.num_hidden_layers), 3):
remote_block = remote_sequential[block_index]
block_index = random.randint(0, config.num_hidden_layers - 1)
remote_block = remote_sequential[block_index]
inputs = torch.randn(1, 8, config.hidden_size)
outputs_forward = remote_block(inputs)
inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
outputs_forward = remote_block(inputs)
outputs_inference = []
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = []
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
# Test long inference (unmerged inference pools)
outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
# test that max length is respected
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
# Test short inference (merged inference pools)
for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
# test that max length is respected
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

@ -7,7 +7,7 @@
import pytest
import torch
from petals import DistributedBloomConfig
from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@ -15,7 +15,7 @@ from test_utils import *
@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
assert isinstance(remote_blocks, RemoteSequential)
@ -43,7 +43,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_blocks = RemoteSequential(config, start_block=3, end_block=5)
inputs = torch.randn(1, 8, config.hidden_size)

@ -3,29 +3,31 @@ import pytest
import torch
import transformers
from hivemind import get_logger
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM
from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin
from petals import DistributedBloomForCausalLM
from petals import AutoDistributedModelForCausalLM
from test_utils import *
logger = get_logger(__name__)
@pytest.fixture
def tokenizer():
# We set use_fast=False since LlamaTokenizerFast is slow on load
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME,
initial_peers=INITIAL_PEERS,
low_cpu_mem_usage=True,
torch_dtype=torch.float32,
active_adapter=ADAPTER_NAME if use_peft else None,
)
config = model.config
assert isinstance(model, DistributedBloomForCausalLM)
assert len(model.transformer.h) == model.config.num_hidden_layers
test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
@ -63,7 +65,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
del model, embs, recurrent_outputs
if REF_NAME:
ref_model = transformers.BloomForCausalLM.from_pretrained(
ref_model = transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if use_peft:
@ -86,27 +88,29 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
def test_greedy_generation(tokenizer, max_new_tokens=4):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
)
hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
"input_ids"
]
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
)
hf_outputs_batch = BloomForCausalLM.greedy_search(
hf_outputs_batch = HfGenerationMixin.greedy_search(
model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
)
assert torch.allclose(
@ -117,13 +121,13 @@ def test_greedy_generation(max_new_tokens=4):
@pytest.mark.forked
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(sampling_options, max_new_tokens=4):
def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
torch.manual_seed(0)
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
with torch.random.fork_rng():
remote_outputs = model.generate(
@ -133,7 +137,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs = BloomForCausalLM.sample(
hf_outputs = HfGenerationMixin.sample(
model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
)
assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
@ -149,7 +153,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs_batch = BloomForCausalLM.sample(
hf_outputs_batch = HfGenerationMixin.sample(
model,
input_ids=inputs_batch,
max_length=inputs_batch.size(1) + max_new_tokens,
@ -161,10 +165,9 @@ def test_sampling(sampling_options, max_new_tokens=4):
@pytest.mark.forked
def test_beam_search_generation(max_new_tokens=4, num_beams=2):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
text = "A cat sat on a mat"
inputs = tokenizer(text, return_tensors="pt")["input_ids"]
@ -181,7 +184,7 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
do_early_stopping=False,
)
hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
hf_outputs = BloomForCausalLM.beam_search(
hf_outputs = HfGenerationMixin.beam_search(
model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
)
assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"

@ -4,7 +4,7 @@ import torch.nn.functional as F
from hivemind import DHT, BatchTensorDescriptor, get_logger
from hivemind.proto import runtime_pb2
from petals import DistributedBloomConfig
from petals import AutoDistributedConfig
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.data_structures import UID_DELIMITER
from petals.server.from_pretrained import load_pretrained_block
@ -15,7 +15,7 @@ logger = get_logger(__name__)
@pytest.mark.forked
def test_remote_sequential():
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
grad_proj = torch.randn(1, 5, config.hidden_size)
@ -40,10 +40,10 @@ def test_remote_sequential():
assert hidden.shape == test_inputs.shape
assert hidden.requires_grad
second_half_outputs = second_half(hidden)
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2)
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
@ -56,7 +56,7 @@ def test_remote_sequential():
(approx_outputs * grad_proj).sum().backward()
assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-3), "compression was not used"
assert abs(approx_outputs - full_outputs).mean() < 0.01
absmax = abs(full_grad).max()
assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
@ -87,7 +87,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)
inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)

@ -5,7 +5,7 @@ import pytest
import torch
from hivemind import DHT, get_logger
from petals import DistributedBloomConfig
from petals import AutoDistributedConfig
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.data_structures import UID_DELIMITER
from test_utils import *
@ -16,7 +16,7 @@ logger = get_logger(__name__)
@pytest.mark.forked
@pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
def test_sequence_manager_basics(mode: str):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
sequential = RemoteSequential(config, dht=dht)
shutdown_evt = threading.Event()

@ -4,14 +4,16 @@ import hivemind
import pytest
import torch
from petals import DistributedBloomConfig, RemoteSequential
from petals import AutoDistributedConfig, RemoteSequential
from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *
@pytest.mark.forked
def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"] # PeerID from server2.id
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)

@ -14,8 +14,11 @@ from test_utils import MODEL_NAME
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
block_index = random.randint(0, 10)
model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
if model_config.model_type != "bloom":
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
block_index = random.randint(0, 10)
block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
tp_config = None
