Mirror of https://github.com/bigscience-workshop/petals (synced 2024-10-31 09:20:41 +00:00)
Merge branch 'main' into repetition-penalty
This commit is contained in: commit dd677d9e76
.github/workflows/check-style.yaml (vendored, 6 changed lines)
@@ -9,7 +9,7 @@ jobs:
   black:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: psf/black@stable
         with:
           options: "--check --diff"
@@ -17,8 +17,8 @@ jobs:
   isort:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v2
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v3
         with:
           python-version: 3.8
       - uses: isort/isort-action@master
.github/workflows/push-docker-image.yaml (vendored, 2 changed lines)
@@ -14,7 +14,7 @@ jobs:

     steps:
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3

       - name: Docker meta
         id: meta
.github/workflows/run-tests.yaml (vendored, 135 changed lines)
@@ -6,57 +6,32 @@ on:
   pull_request:

 jobs:
-  convert-model:
-    runs-on: ubuntu-latest
-    env:
-      BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }}
-    timeout-minutes: 15
-    steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python
-        uses: actions/setup-python@v2
-        with:
-          python-version: 3.9
-      - name: Cache dependencies
-        uses: actions/cache@v2
-        with:
-          path: ~/.cache/pip
-          key: Key-v1-py3.9-${{ hashFiles('setup.cfg') }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install .
-      - name: Delete any test models older than 1 week
-        run: |
-          python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
-      - name: Delete previous version of this model, if exists
-        run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
-          python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
-            repo_id='bloom-testing/test-bloomd-560m-$HF_TAG')" || true
-      - name: Convert model and push to hub
-        run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
-          python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \
-            --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \
-            --resize_token_embeddings 50000

   run-tests:
     runs-on: ubuntu-latest
-    needs: convert-model
     strategy:
       matrix:
-        python-version: [ 3.7, 3.8, 3.9 ]
+        include:
+          - { model: 'bigscience/bloom-560m', python-version: '3.8' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.9' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.10' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
+          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
       fail-fast: false
     timeout-minutes: 15
     steps:
-      - uses: actions/checkout@v2
+      - name: Increase swap space
+        uses: pierotofy/set-swap-space@master
+        with:
+          swap-size-gb: 10
+      - name: Checkout
+        uses: actions/checkout@v3
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v3
         with:
           python-version: ${{ matrix.python-version }}
       - name: Cache dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: ~/.cache/pip
           key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}
@@ -66,47 +41,77 @@ jobs:
           pip install .[dev]
       - name: Test
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
-          export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG
-          export REF_NAME=bigscience/bloom-560m
+          export MODEL_NAME="${{ matrix.model }}"
+          export REF_NAME="${{ matrix.model }}"
+          export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
+          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
-            --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-            --torch_dtype float32 --compression NONE --attn_cache_size 0.2GiB &> server1.log &
-          SERVER1_PID=$!
-
-          sleep 5 # wait for the first server to initialize DHT
+          # [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
+          bash -c 'while true; do free -h && sleep 30s; done' &
+          RAM_WATCH_PID=$!
+
+          # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+
+          python -m petals.cli.run_dht \
+            --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
+          BOOTSTRAP_PID=$!

           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
-          # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
+          # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log &
+          sleep 5 # wait for DHT init
+
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
+            --mean_balance_check_period 10 \
+            --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+          SERVER1_PID=$!
+          # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
+
+          sleep 10 # wait for the 1st server to choose blocks
+
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
+            --identity_path tests/server2.id \
+            --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
           SERVER2_PID=$!

-          sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
-
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
+            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
+            --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
           SERVER3_PID=$!
+          # ^-- chunking test

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
-            --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
+          python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
+            --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
           SERVER4_PID=$!
+          # ^-- tensor parallelism test (not compatible with adapters yet)

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server5.log &
-          SERVER5_PID=$!
+          sleep 5 # wait for the log files to appear

-          tail -n 100 -f server*.log &
+          tail -n 100 -f bootstrap.log server*.log &
           LOGGER_PID=$!
-          sleep 30 # wait for servers to download layers

-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init
+          sleep 30 # wait for servers to eval throughput, download layers, and rebalance
+          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init

+          # [Step 3] Run PyTest
+
           pytest tests --durations=0 --durations-min=1.0 -v

-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests
-
-          kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID $LOGGER_PID
+          # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
+
+          python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3
+          python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --n_steps 1
+          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls
+          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
+
+          # [Step 5] Clean up
+
+          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests
+
+          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
           echo "Done!"
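For context, both the pytest run and the benchmark scripts in this workflow talk to this throwaway local swarm rather than the public one. A minimal sketch of how a client points at it, assuming the `MODEL_NAME` and `INITIAL_PEERS` values exported in the script above (this is an illustration, not part of the diff):

```python
# Sketch: connect a Petals client to the private CI test swarm via its bootstrap multiaddr.
import os

import torch
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM

model_name = os.environ["MODEL_NAME"]          # e.g. bigscience/bloom-560m in this workflow
initial_peers = [os.environ["INITIAL_PEERS"]]  # multiaddr of the run_dht bootstrap peer

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(
    model_name, initial_peers=initial_peers, torch_dtype=torch.float32
)

inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))
```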
.gitignore (vendored, 2 changed lines)
@@ -126,3 +126,5 @@ dmypy.json

 # Pyre type checker
 .pyre/
+
+.idea/
Dockerfile
@@ -17,7 +17,7 @@ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -
     bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh
 ENV PATH="/opt/conda/bin:${PATH}"

-RUN conda install python~=3.10 pip && \
+RUN conda install python~=3.10.12 pip && \
     pip install --no-cache-dir "torch>=1.12" && \
     conda clean --all && rm -rf ~/.cache/pip

README.md (293 changed lines)
@@ -1,154 +1,241 @@
 <p align="center">
     <img src="https://i.imgur.com/7eR7Pan.png" width="400"><br>
-    Run 100B+ language models at home, BitTorrent-style.<br>
-    Fine-tuning and inference up to 10x faster than offloading<br><br>
+    Run large language models at home, BitTorrent-style.<br>
+    Fine-tuning and inference <a href="https://github.com/bigscience-workshop/petals#benchmarks">up to 10x faster</a> than offloading
+    <br><br>
+    <a href="https://pypi.org/project/petals/"><img src="https://img.shields.io/pypi/v/petals.svg?color=green"></a>
+    <a href="https://discord.gg/tfHfe8B34k"><img src="https://img.shields.io/discord/865254854262652969?label=discord&logo=discord&logoColor=white"></a>
+    <br>
 </p>

-Generate text using distributed BLOOM and fine-tune it for your own tasks:
+Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks — right from your desktop computer or Google Colab:

 ```python
-from petals import DistributedBloomForCausalLM
+from transformers import AutoTokenizer
+from petals import AutoDistributedModelForCausalLM

-# Embeddings & prompts are on your device, BLOOM blocks are distributed across the Internet
-model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", tuning_mode="ptune")
+model_name = "stabilityai/StableBeluga2"
+# You can also use "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-70b-chat-hf",
+# repos with LLaMA-65B, "bigscience/bloom", or "bigscience/bloomz"

+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
+# Embeddings & prompts are on your device, transformer blocks are distributed across the Internet

 inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
 outputs = model.generate(inputs, max_new_tokens=5)
-print(tokenizer.decode(remote_outputs[0])) # A cat sat on a mat...
+print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
-
-# Training (updates only prompts or adapters hosted locally)
-optimizer = torch.optim.AdamW(model.parameters())
-for input_ids, labels in data_loader:
-    outputs = model.forward(input_ids)
-    loss = cross_entropy(outputs.logits, labels)
-    optimizer.zero_grad()
-    loss.backward()
-    optimizer.step()
 ```

 <p align="center">
-    🚀 <b><a href="https://colab.research.google.com/drive/1Ervk6HPNS6AYVr3xVdQnY5a-TjjmLCdQ?usp=sharing">Try now in Colab</a></b>
+    🚀 <b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
 </p>

-Connect your own GPU and increase Petals capacity:
+🦙 **Want to run LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
+
+📋 **Terms of use.** Make sure you follow the model license (see [LLaMA 2](https://bit.ly/llama2-license), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2/blob/main/LICENSE.txt), [LLaMA](https://bit.ly/llama-license), and [BLOOM](https://bit.ly/bloom-license)).
+
+🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
+
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
+
+### Connect your GPU and increase Petals capacity
+
+Petals is a community-run system — we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serving one of the models!
+
+🐍 **Linux + Anaconda.** Run these commands:

 ```bash
-# In an Anaconda env
-(conda) $ conda install pytorch cudatoolkit=11.3 -c pytorch
-(conda) $ pip install git+https://github.com/bigscience-workshop/petals
-(conda) $ python -m petals.cli.run_server bigscience/bloom-petals
-
-# Or using a GPU-enabled Docker image
-sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
-    python -m petals.cli.run_server bigscience/bloom-petals
+conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
+pip install git+https://github.com/bigscience-workshop/petals
+python -m petals.cli.run_server stabilityai/StableBeluga2
 ```

-💬 If you have any issues or feedback, please join [our Discord server](https://discord.gg/D9MwApKgWa)!
-
-Check out more tutorials:
-
-- Training a personified chatbot: [notebook](./examples/prompt-tuning-personachat.ipynb)
-- Fine-tuning BLOOM for text semantic classification: [notebook](./examples/prompt-tuning-sst2.ipynb)
-- Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
-- Running a custom foundation model: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
+🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
+
+🐋 **Any OS + Docker.** Run our [Docker](https://www.docker.com) image:
+
+```bash
+sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
+    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
+```
+
+These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
+
+🦙 **Want to host LLaMA 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then use this command for `petals.cli.run_server`:
+
+```bash
+python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
+```
+
+💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
+
+🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
+
+🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
+
+### Check out tutorials, examples, and more
+
+Basic tutorials:
+
+- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
+- Prompt-tune LLaMA-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
+- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
+
+Useful tools and advanced guides:
+
+- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
+- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
+- Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+- Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
+
+Learning more:
+
+- Frequently asked questions: [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions)
+- In-depth system description: [paper](https://arxiv.org/abs/2209.01188)

 ## How does it work?

-- Petals runs large language models like BLOOM-176B **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning.
+- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
-- Inference runs at ≈ 1 sec per step (token) — 10x faster than possible with offloading, enough for chatbots and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
+- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
-- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. You get the comforts of an API with the flexibility of PyTorch.
+- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.

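As a small illustration of the "hidden states" point in the new bullet above, here is a hedged sketch that mirrors `benchmarks/benchmark_forward.py` added in this commit, which runs the distributed transformer blocks without the LM head (the model name is only an example of a checkpoint served by the public swarm):

```python
# Sketch: get raw hidden states from the distributed backbone instead of logits.
import torch
from transformers import AutoTokenizer

from petals import AutoDistributedModel  # backbone only, no lm_head

model_name = "stabilityai/StableBeluga2"  # example checkpoint from the README snippet above
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModel.from_pretrained(model_name)

inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
with torch.inference_mode():
    outputs = model(inputs)  # forward pass through the remote blocks; last-layer activations
# You can feed these hidden states into your own head, probe, or loss function locally.
```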
 <p align="center">
     <img src="https://i.imgur.com/RTYF3yW.png" width="800">
 </p>

 <p align="center">
+    📚 <b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
+
     📜 <b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
 </p>

-### 🔒 Privacy and security
-
-The Petals public swarm is designed for research and academic use. **Please do not use the public swarm to process sensitive data.** We ask for that because it is an open network, and it is technically possible for peers serving model layers to recover input data and model outputs or modify them in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process your data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
-
-### 📋 Model's terms of use
-
-Before building your own application that runs a language model with Petals, please check out the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
-
-## FAQ
-
-1. **What's the motivation for people to host model layers in the public swarm?**
-
-    People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may be also motivated to "give back" to the community helping them to run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded).
-
-    Since it may be not enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who earned these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards.
-
-2. **Why is the platform named "Petals"?**
-
-    "Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model — [BLOOM](https://huggingface.co/bigscience/bloom).
-
-    While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in future.
-
 ## Installation

-Here's how to install Petals with conda:
-```
-conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
+Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/distribution) on Linux:
+
+```bash
+conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
 ```

-This script uses Anaconda to install cuda-enabled PyTorch.
-If you don't have anaconda, you can get it from [here](https://www.anaconda.com/products/distribution).
-If you don't want anaconda, you can install PyTorch [any other way](https://pytorch.org/get-started/locally/).
-If you want to run models with 8-bit weights, please install **PyTorch with CUDA 11** or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
+If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).

-__System requirements:__ Petals only supports Linux for now. If you don't have a Linux machine, consider running Petals in Docker (see our [image](https://hub.docker.com/r/learningathome/petals)) or, in case of Windows, in WSL2 ([read more](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)). CPU is enough to run a client, but you probably need a GPU to run a server efficiently.
+See the instructions for macOS and Windows, the full requirements, and troubleshooting advice in our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-client).

-## 🛠️ Development
+## Benchmarks

-Petals uses pytest with a few plugins. To install them, run:
+The benchmarks below are for BLOOM-176B:

-```python
-git clone https://github.com/bigscience-workshop/petals.git && cd petals
-pip install -e .[dev]
+<table align="center">
+  <tr>
+    <th colspan="2">Network</th>
+    <th colspan="2">Single-batch inference<br>(steps/s)</th>
+    <th colspan="2">Parallel forward<br>(tokens/s)</th>
+  </tr>
+  <tr>
+    <th rowspan="2">Bandwidth</th>
+    <th rowspan="2">Round-trip<br>latency</th>
+    <th colspan="2">Sequence length</th>
+    <th colspan="2">Batch size</th>
+  </tr>
+  <tr align="center">
+    <td>128</td>
+    <td>2048</td>
+    <td>1</td>
+    <td>64</td>
+  </tr>
+  <tr>
+    <th colspan="6">Offloading, max. possible speed on 1x A100 <sup>1</sup></th>
+  </tr>
+  <tr align="center">
+    <td>256 Gbit/s</td>
+    <td></td>
+    <td>0.18</td>
+    <td>0.18</td>
+    <td>2.7</td>
+    <td>170.3</td>
+  </tr>
+  <tr align="center">
+    <td>128 Gbit/s</td>
+    <td></td>
+    <td>0.09</td>
+    <td>0.09</td>
+    <td>2.4</td>
+    <td>152.8</td>
+  </tr>
+  <tr>
+    <th colspan="6">Petals on 14 heterogeneous servers across Europe and North America <sup>2</sup></th>
+  </tr>
+  <tr align="center">
+    <td colspan="2">Real world</td>
+    <td>0.83</td>
+    <td>0.79</td>
+    <td>32.6</td>
+    <td>179.4</td>
+  </tr>
+  <tr>
+    <th colspan="6">Petals on 3 servers, with one A100 each <sup>3</sup></th>
+  </tr>
+  <tr align="center">
+    <td>1 Gbit/s</td>
+    <td>< 5 ms</td>
+    <td>1.71</td>
+    <td>1.54</td>
+    <td>70.0</td>
+    <td>253.6</td>
+  </tr>
+  <tr align="center">
+    <td>100 Mbit/s</td>
+    <td>< 5 ms</td>
+    <td>1.66</td>
+    <td>1.49</td>
+    <td>56.4</td>
+    <td>182.0</td>
+  </tr>
+  <tr align="center">
+    <td>100 Mbit/s</td>
+    <td>100 ms</td>
+    <td>1.23</td>
+    <td>1.11</td>
+    <td>19.7</td>
+    <td>112.2</td>
+  </tr>
+</table>
+
+<sup>1</sup> **An upper bound for offloading performance.** We base our offloading numbers on the best possible hardware setup for offloading: CPU RAM offloading via PCIe 4.0 with 16 PCIe lanes per GPU and PCIe switches for pairs of GPUs. We assume zero latency for the upper bound estimation. In 8-bit, the model uses 1 GB of memory per billion parameters. PCIe 4.0 with 16 lanes has a throughput of 256 Gbit/s, so offloading 176B parameters takes 5.5 seconds. The throughput is twice as slow (128 Gbit/s) if we have two GPUs behind the same PCIe switch.
+
+<sup>2</sup> **A real-world distributed setting** with 14 servers holding 2× RTX 3060, 4× 2080Ti, 2× 3090, 2× A4000, and 4× A5000 GPUs. These are personal servers and servers from university labs, spread across Europe and North America and connected to the Internet at speeds of 100–1000 Mbit/s. 4 servers operate from under firewalls.
+
+<sup>3</sup> **An optimistic setup** that requires least communication. The client nodes have 8 CPU cores and no GPU.
+
+We provide more evaluations and discuss these results in more detail in **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).

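As a quick sanity check, the offloading figures in footnote 1 follow directly from the stated bandwidth and model size (this restates the footnote's arithmetic, nothing more):

```python
# Offloading upper bound from footnote 1: 176B parameters at 1 GB per billion parameters in 8-bit.
model_size_gb = 176            # ~176 GB of weights in 8-bit
pcie_gbit_per_s = 256          # PCIe 4.0 x16 throughput
seconds_per_pass = model_size_gb * 8 / pcie_gbit_per_s
print(seconds_per_pass)        # 5.5 s per pass -> ~0.18 steps/s, matching the first offloading row
print(model_size_gb * 8 / 128) # 11.0 s per pass -> ~0.09 steps/s with two GPUs behind one PCIe switch
```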
+## 🛠️ Contributing
+
+Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.
+
+## 📜 Citation
+
+Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
+[Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)
+_arXiv preprint arXiv:2209.01188,_ 2022.
+
+```bibtex
+@article{borzunov2022petals,
+  title = {Petals: Collaborative Inference and Fine-tuning of Large Models},
+  author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Ryabinin, Max and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin},
+  journal = {arXiv preprint arXiv:2209.01188},
+  year = {2022},
+  url = {https://arxiv.org/abs/2209.01188}
+}
 ```

-To run minimalistic tests, you need to make a local swarm with a small model and some servers. You may find more information about how local swarms work and how to run them in [this tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
-
-```bash
-export MODEL_NAME=bloom-testing/test-bloomd-560m-main
-
-python -m petals.cli.run_server $MODEL_NAME --block_indices 0:12 \
-    --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --new_swarm &> server1.log &
-sleep 5 # wait for the first server to initialize DHT
-
-python -m petals.cli.run_server $MODEL_NAME --block_indices 12:24 \
-    --initial_peers SEE_THE_OUTPUT_OF_THE_1ST_PEER &> server2.log &
-
-tail -f server1.log server2.log # view logs for both servers
-```
-
-Then launch pytest:
-
-```
-export MODEL_NAME=bloom-testing/test-bloomd-560m-main REF_NAME=bigscience/bloom-560m
-export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
-PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
-```
-
-After you're done, you can terminate the servers and ensure that no zombie processes are left with `pkill -f petals.cli.run_server && pkill -f p2p`.
-
-The automated tests use a more complex server configuration that can be found [here](https://github.com/bigscience-workshop/petals/blob/main/.github/workflows/run-tests.yaml).
-
-### Code style
-
-We use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests.
-Before committing your code, simply run `black . && isort .` and you will be fine.

 --------------------------------------------------------------------------------

 <p align="center">
     This project is a part of the <a href="https://bigscience.huggingface.co/">BigScience</a> research workshop.
 </p>
 <p align="center">
-    <img src="https://petals.ml/bigscience.png" width="150">
+    <img src="https://petals.dev/bigscience.png" width="150">
 </p>
benchmarks/benchmark_forward.py (new executable file, 75 lines)
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger

from petals import AutoDistributedModel
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
    logger.info(f"Final result: {speed=:.2f}")


@torch.inference_mode()
def benchmark_forward(process_idx, args, result_pipe):
    model = AutoDistributedModel.from_pretrained(
        args.model,
        initial_peers=args.initial_peers,
        torch_dtype=DTYPE_MAP[args.torch_dtype],
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    step_times = []
    for step in range(args.warmup_steps + args.n_steps):
        start_time = perf_counter()

        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))

        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
        h = model(input_ids)
        # We don't use model.lm_head
        logger.info(f"{process_idx=} Fwd end")

        if step >= args.warmup_steps:
            step_times.append(perf_counter() - start_time)
            speed = input_ids.numel() / np.mean(step_times)
            logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)


if __name__ == "__main__":
    main()
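The throughput this script reports is tokens processed per second: `batch_size * seq_len` divided by the mean post-warmup step time, which is what the "Parallel forward (tokens/s)" column in the README measures. A tiny worked example with illustrative (hypothetical) numbers:

```python
# Worked example of the speed formula used above (illustrative numbers, not a measurement).
batch_size, seq_len = 3, 128   # e.g. --batch_size 3 with the default --seq_len 128
mean_step_time = 2.0           # hypothetical seconds per forward pass
speed = batch_size * seq_len / mean_step_time
print(f"{speed:.1f} tokens/s")  # 192.0 tokens/s
```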
benchmarks/benchmark_inference.py (new executable file, 72 lines)
@@ -0,0 +1,72 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
    logger.info(f"Final result: {speed=:.2f}")


@torch.inference_mode()
def benchmark_inference(process_idx, args, result_pipe):
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway

    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    result = ""
    step_times = []
    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
        for step in range(args.seq_len):
            start_time = perf_counter()

            outputs = model.generate(max_new_tokens=1, session=sess)
            result += tokenizer.decode(outputs[0])

            if step >= args.warmup_steps:
                step_times.append(perf_counter() - start_time)
                speed = 1 / np.mean(step_times)
                logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)


if __name__ == "__main__":
    main()
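Unlike the forward benchmark, this one reports steps per second (1 / mean per-token latency) inside a single inference session, so the remote servers keep the attention cache between `generate()` calls. A condensed sketch of the same pattern outside the multiprocessing harness, assuming a BLOOM-family checkpoint where the block list lives at `model.transformer.h`, as in the script above, and a swarm that actually serves that checkpoint:

```python
# Sketch: reuse one inference session for token-by-token generation.
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM

model_name = "bigscience/bloom-560m"  # example; any checkpoint served by your swarm
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

result = ""
with model.transformer.h.inference_session(max_length=32) as sess:
    for _ in range(8):  # one token per call; the attention cache stays on the servers
        outputs = model.generate(max_new_tokens=1, session=sess)
        result += tokenizer.decode(outputs[0])
print(result)
```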
benchmarks/benchmark_training.py (new executable file, 107 lines)
@@ -0,0 +1,107 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger

from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
    parser.add_argument("--task", type=str, default="cls", help="Training task type")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
    parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    assert args.task in ["cls", "causal_lm"]

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0)
    logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}")


def benchmark_training(process_idx, args, result_pipe):
    if args.task == "cls":
        model = AutoDistributedModelForSequenceClassification.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
            num_labels=2,
        )
    elif args.task == "causal_lm":
        model = AutoDistributedModelForCausalLM.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
        )
    model = model.to(args.device)
    opt = torch.optim.Adam(model.parameters())
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    fwd_times = []
    bwd_times = []
    for step in range(args.warmup_steps + args.n_steps):
        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
        if args.task == "cls":
            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
        else:
            labels = input_ids

        logger.info(f"{process_idx=} {step=} Forward")
        start_time = perf_counter()
        outputs = model(input_ids, labels=labels)
        if step >= args.warmup_steps:
            fwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Backward")
        start_time = perf_counter()
        outputs.loss.backward()
        if step >= args.warmup_steps:
            bwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Optimizer step")
        opt.step()
        opt.zero_grad()

        if step >= args.warmup_steps:
            fwd_speed = input_ids.numel() / np.mean(fwd_times)
            bwd_speed = input_ids.numel() / np.mean(bwd_times)
            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")

    result_pipe.send((fwd_speed, bwd_speed))


if __name__ == "__main__":
    main()
examples/prompt-tuning-personachat.ipynb
@@ -11,9 +11,9 @@
     "\n",
     "# Distributed Bloom for Text Generation using Prompt Tuning\n",
     "\n",
-    "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
+    "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
     "\n",
-    "We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n",
+    "We will adapt BLOOM for the task of creating a chatbot with a specific personality using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n",
     "\n",
     "To use this notebook in Colab:\n",
     "\n",
@@ -36,8 +36,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
-    "!pip install -q datasets wandb"
+    "%pip install -q petals datasets wandb scikit-learn"
    ]
   },
   {
@@ -76,7 +75,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "MODEL_NAME = \"bigscience/bloom-petals\" # select model you like\n",
+    "# Choose a model you'd like to prompt-tune. We recommend starting with\n",
+    "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n",
+    "# Once your code is ready, you can switch to full-scale\n",
+    "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n",
+    "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n",
+    "\n",
+    "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
+    "# The latter fine-tunes separate prefixes for each transformer block,\n",
+    "# so prompt-tuning will take more time but yield better results.\n",
+    "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n",
+    "TUNING_MODE = 'ptune'\n",
+    "\n",
     "NUM_PREFIX_TOKENS = 16\n",
     "DEVICE = 'cuda'\n",
     "BATCH_SIZE = 8\n",
@@ -84,8 +94,7 @@
     "WEIGHT_DECAY = 0.0\n",
     "NUM_SAMPLES = 1000\n",
     "SEED = 42\n",
-    "MODEL_MAX_LENGTH = 256\n",
-    "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
+    "MODEL_MAX_LENGTH = 256"
    ]
   },
   {
@@ -276,7 +285,7 @@
     " user_phrase = input()\n",
     " if len(user_phrase) == 0:\n",
     " break\n",
-    " inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n",
+    " inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n",
     " while True:\n",
     " outputs = model.generate(\n",
     " inputs,\n",
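A rough way to see the cost difference between the two tuning modes described in the new notebook comments: 'ptune' learns one prefix at the input, while 'deep_ptune' learns a separate prefix per transformer block, so its trainable-parameter count is roughly `num_layers` times larger. A back-of-the-envelope sketch, where the hidden size and layer count are assumed values for a ~7B model and only serve as illustration:

```python
# Estimate of trainable prefix parameters for the two tuning modes (assumed model dimensions).
num_prefix_tokens = 16  # NUM_PREFIX_TOKENS from the notebook
hidden_size = 4096      # assumed hidden size of a ~7B model
num_layers = 30         # assumed number of transformer blocks

ptune_params = num_prefix_tokens * hidden_size                     # one shared input prefix
deep_ptune_params = num_prefix_tokens * hidden_size * num_layers   # one prefix per block
print(f"ptune: ~{ptune_params / 1e6:.2f}M, deep_ptune: ~{deep_ptune_params / 1e6:.1f}M trainable parameters")
```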
examples/prompt-tuning-sst2.ipynb
@@ -3,17 +3,19 @@
  {
   "cell_type": "markdown",
   "id": "a07e0f5e",
-  "metadata": {},
+  "metadata": {
+   "id": "a07e0f5e"
+  },
   "source": [
    "<div>\n",
    "<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\"> \n",
    "</div>\n",
    "\n",
-   "# Distributed Bloom for Text Classification using Prompt Tuning\n",
+   "# Distributed LLaMA for Text Classification using Prompt Tuning\n",
    "\n",
-   "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
+   "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [LLaMA](https://github.com/facebookresearch/llama) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the LLaMA blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
    "\n",
-   "We will adapt the BLOOM model for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n",
+   "We will adapt LLaMA for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n",
    "\n",
    "To use this notebook in Colab:\n",
    "\n",
@@ -24,7 +26,9 @@
  {
   "cell_type": "markdown",
   "id": "a3f8526f",
-  "metadata": {},
+  "metadata": {
+   "id": "a3f8526f"
+  },
   "source": [
    "First, we have to prepare all dependencies."
   ]
@@ -33,18 +37,22 @@
   "cell_type": "code",
   "execution_count": null,
   "id": "73bbc648",
-  "metadata": {},
+  "metadata": {
+   "id": "73bbc648"
+  },
   "outputs": [],
   "source": [
-   "!pip install -q git+https://github.com/bigscience-workshop/petals\n",
-   "!pip install -q datasets wandb"
+   "%pip install -q datasets wandb scikit-learn\n",
+   "%pip install -q git+https://github.com/bigscience-workshop/petals@main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4ab6ca7",
-  "metadata": {},
+  "metadata": {
+   "id": "b4ab6ca7"
+  },
   "outputs": [],
   "source": [
    "import os\n",
@@ -58,15 +66,19 @@
    "from tqdm import tqdm\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
-   "from transformers import BloomTokenizerFast, get_scheduler\n",
+   "from transformers import LlamaTokenizer, get_scheduler, set_seed\n",
    "\n",
-   "from petals import DistributedBloomForSequenceClassification"
+   "from petals import DistributedLlamaForSequenceClassification\n",
+   "\n",
+   "set_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bf07b5d",
-  "metadata": {},
+  "metadata": {
+   "id": "1bf07b5d"
+  },
   "source": [
    "Let's set some hyperparameters for training:"
   ]
@@ -75,50 +87,66 @@
   "cell_type": "code",
   "execution_count": null,
   "id": "f04ba4d2",
-  "metadata": {},
+  "metadata": {
+   "id": "f04ba4d2"
+  },
   "outputs": [],
   "source": [
-   "MODEL_NAME = \"bigscience/bloom-petals\" # select model you like\n",
-   "NUM_PREFIX_TOKENS = 16\n",
+   "MODEL_NAME = \"enoch/llama-65b-hf\"\n",
+   "\n",
+   "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
+   "# The latter fine-tunes separate prefixes for each transformer block,\n",
+   "# so prompt-tuning will take more time but yield better results.\n",
+   "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n",
+   "TUNING_MODE = 'ptune'\n",
+   "\n",
+   "NUM_PREFIX_TOKENS = 8\n",
    "DEVICE = 'cuda'\n",
-   "BATCH_SIZE = 16\n",
+   "BATCH_SIZE = 32\n",
    "LR = 1e-2\n",
    "WEIGHT_DECAY = 0.0\n",
    "NUM_EPOCHS = 3\n",
    "SEED = 42\n",
-   "MODEL_MAX_LENGTH = 64\n",
-   "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
+   "MODEL_MAX_LENGTH = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d38316bd",
-  "metadata": {},
+  "metadata": {
+   "id": "d38316bd"
+  },
   "source": [
-   "Prepare tokenizer and distributed model, connect it to servers."
+   "Here, we prepare tokenizer and distributed model and connect it to the public swarm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03c6e53e",
-  "metadata": {},
+  "metadata": {
+   "id": "03c6e53e"
+  },
   "outputs": [],
   "source": [
-   "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+   "tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\n",
    "tokenizer.padding_side = 'right'\n",
    "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
-   "model = DistributedBloomForSequenceClassification.from_pretrained(\n",
|
"tokenizer.pad_token = tokenizer.unk_token\n",
|
||||||
|
"model = DistributedLlamaForSequenceClassification.from_pretrained(\n",
|
||||||
" MODEL_NAME,\n",
|
" MODEL_NAME,\n",
|
||||||
" pre_seq_len=NUM_PREFIX_TOKENS,\n",
|
" pre_seq_len=NUM_PREFIX_TOKENS,\n",
|
||||||
" tuning_mode=TUNING_MODE\n",
|
" tuning_mode=TUNING_MODE\n",
|
||||||
").to(DEVICE)"
|
").float().to(DEVICE)\n",
|
||||||
|
"model.config.pad_token_id = tokenizer.pad_token_id"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "042e3786",
|
"id": "042e3786",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "042e3786"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset."
|
"Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset."
|
||||||
]
|
]
|
||||||
@ -127,7 +155,9 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "9c44d516",
|
"id": "9c44d516",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "9c44d516"
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"task = 'sst2'\n",
|
"task = 'sst2'\n",
|
||||||
@ -135,7 +165,7 @@
|
|||||||
"dataset = load_dataset(\"glue\", task)\n",
|
"dataset = load_dataset(\"glue\", task)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def preprocess_function(examples):\n",
|
"def preprocess_function(examples):\n",
|
||||||
" return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True)\n",
|
" return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True, return_token_type_ids=False)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
|
"tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
|
||||||
"tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n",
|
"tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n",
|
||||||
@ -152,16 +182,20 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "2a3f3590",
|
"id": "2a3f3590",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "2a3f3590"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"To check training, we need a metric function. For SST-2 task is accuracy. We will load it from the datasets library."
|
"To monitor training, we need the metric function. For SST-2, the target metric is accuracy. We will load it from the datasets library."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "1e1812be",
|
"id": "1e1812be",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "1e1812be"
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"metric = load_metric('glue', task)\n",
|
"metric = load_metric('glue', task)\n",
|
||||||
@ -170,7 +204,7 @@
|
|||||||
" model.eval()\n",
|
" model.eval()\n",
|
||||||
" for batch in dataloader:\n",
|
" for batch in dataloader:\n",
|
||||||
" batch = {k: v.to(device) for k, v in batch.items()}\n",
|
" batch = {k: v.to(device) for k, v in batch.items()}\n",
|
||||||
" \n",
|
"\n",
|
||||||
" with torch.no_grad():\n",
|
" with torch.no_grad():\n",
|
||||||
" outputs = model(**batch)\n",
|
" outputs = model(**batch)\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -184,16 +218,20 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "ef4323fd",
|
"id": "ef4323fd",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "ef4323fd"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Before setting up optimizers, check the model parameters that will be trained."
|
"Before setting up optimizers, let's check the model parameters that will be trained."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "9cc0ba34",
|
"id": "9cc0ba34",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "9cc0ba34"
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"for n, p in model.named_parameters():\n",
|
"for n, p in model.named_parameters():\n",
|
||||||
@ -204,29 +242,35 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "59cffce7",
|
"id": "59cffce7",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "59cffce7"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"The optimizer will only work on **prompts**, they are only trainable parameters. Let's initialize optimizer and learning rate scheduler."
|
"The optimizer will only work on **prompts and classifier head**: they are only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "ef9bf344",
|
"id": "ef9bf344",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "ef9bf344"
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
|
"optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"lr_scheduler = get_scheduler(\n",
|
"lr_scheduler = get_scheduler(\n",
|
||||||
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
|
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "423c56d5",
|
"id": "423c56d5",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "423c56d5"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Let's initialize wandb for logging and start the training loop!"
|
"Let's initialize wandb for logging and start the training loop!"
|
||||||
]
|
]
|
||||||
@ -235,7 +279,9 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "d9e46807",
|
"id": "d9e46807",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "d9e46807"
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"wandb.init(\n",
|
"wandb.init(\n",
|
||||||
@ -251,20 +297,24 @@
|
|||||||
" }\n",
|
" }\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"scaler = torch.cuda.amp.GradScaler()\n",
|
||||||
|
"\n",
|
||||||
"for epoch in range(NUM_EPOCHS):\n",
|
"for epoch in range(NUM_EPOCHS):\n",
|
||||||
|
" model.train()\n",
|
||||||
" for batch in tqdm(train_dataloader):\n",
|
" for batch in tqdm(train_dataloader):\n",
|
||||||
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
|
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
|
||||||
"\n",
|
"\n",
|
||||||
" model.train()\n",
|
" with torch.autocast(device_type=DEVICE, dtype=torch.float16):\n",
|
||||||
" outputs = model(**batch)\n",
|
" outputs = model(**batch)\n",
|
||||||
" loss = outputs.loss\n",
|
" loss = outputs.loss\n",
|
||||||
" loss.backward()\n",
|
" scaler.scale(loss).backward()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" optimizer.step()\n",
|
" scaler.step(optimizer)\n",
|
||||||
|
" scaler.update()\n",
|
||||||
" lr_scheduler.step()\n",
|
" lr_scheduler.step()\n",
|
||||||
" optimizer.zero_grad()\n",
|
" optimizer.zero_grad()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" wandb.log({\"Train Loss\": loss})\n",
|
" wandb.log({\"Train Loss\": loss.detach()})\n",
|
||||||
"\n",
|
"\n",
|
||||||
" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
|
" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
|
||||||
" wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
|
" wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
|
||||||
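For reference, the updated training loop above follows the standard PyTorch mixed-precision recipe: the forward pass runs under autocast in float16, the loss is scaled before backward so small gradients do not underflow, and the scaler unscales the gradients before stepping the optimizer. A minimal, Petals-independent sketch of that pattern (the model and data here are stand-ins) looks like this:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    x = torch.randn(4, 16, device="cuda")
    y = torch.randint(0, 2, (4,), device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = F.cross_entropy(model(x), y)
    scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)          # unscales gradients; skips the step if they contain inf/nan
    scaler.update()                 # adjust the scale factor for the next iteration
    optimizer.zero_grad()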
@ -273,183 +323,26 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "51770911",
|
"id": "51770911",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"id": "51770911"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Our model have been trained!"
|
"Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](https://health.petals.dev/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "1bbf014f",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Beyond soft-prompt tuning\n",
|
|
||||||
"\n",
|
|
||||||
"Let's try to tune model using adapters in the middle of the model."
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "3bea4391",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [],
|
||||||
"class BloomBasedClassifier(nn.Module):\n",
|
"metadata": {
|
||||||
" def __init__(\n",
|
"collapsed": false
|
||||||
" self,\n",
|
}
|
||||||
" model,\n",
|
|
||||||
" intermediate_size: int = 32,\n",
|
|
||||||
" num_classes: int = 2,\n",
|
|
||||||
" adapter_layer_position: int = 6,\n",
|
|
||||||
" head_layer_position: int = 10\n",
|
|
||||||
" ):\n",
|
|
||||||
" super().__init__()\n",
|
|
||||||
" self.distributed_layers = model.transformer.h\n",
|
|
||||||
"\n",
|
|
||||||
" self.hidden_size = model.config.hidden_size\n",
|
|
||||||
" self.intermediate_size = intermediate_size\n",
|
|
||||||
" self.num_classes = num_classes\n",
|
|
||||||
" self.adapter_layer_position = adapter_layer_position\n",
|
|
||||||
" self.head_layer_position = head_layer_position\n",
|
|
||||||
" \n",
|
|
||||||
" self.adapter = nn.Sequential(\n",
|
|
||||||
" nn.Linear(self.hidden_size, self.intermediate_size),\n",
|
|
||||||
" nn.Linear(self.intermediate_size, self.hidden_size),\n",
|
|
||||||
" )\n",
|
|
||||||
" self.head = nn.Sequential(\n",
|
|
||||||
" nn.LayerNorm(self.hidden_size),\n",
|
|
||||||
" nn.Linear(self.hidden_size, self.num_classes),\n",
|
|
||||||
" )\n",
|
|
||||||
" \n",
|
|
||||||
" def forward(self, embeddings):\n",
|
|
||||||
" before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
|
|
||||||
" after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n",
|
|
||||||
" \n",
|
|
||||||
" hidden_states = before_layers(embeddings)\n",
|
|
||||||
" hidden_states = self.adapter(hidden_states)\n",
|
|
||||||
" hidden_states = after_layers(hidden_states)\n",
|
|
||||||
" pooled_states = torch.mean(hidden_states, dim=1)\n",
|
|
||||||
" return self.head(pooled_states)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "15299620",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Clear model and device memory."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "aa27b168",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"del model, optimizer, lr_scheduler\n",
|
|
||||||
"torch.cuda.empty_cache()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "5406390f",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Create new model with adapters."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "a251db80",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"INTERMEDIATE_SIZE = 32\n",
|
|
||||||
"ADAPTER_LAYER_POSITION = 6\n",
|
|
||||||
"HEAD_LAYER_POSITION = 10"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "3578df3a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
|
|
||||||
"\n",
|
|
||||||
"cls_model = BloomBasedClassifier(\n",
|
|
||||||
" model,\n",
|
|
||||||
" intermediate_size=INTERMEDIATE_SIZE,\n",
|
|
||||||
" adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
|
|
||||||
" head_layer_position=HEAD_LAYER_POSITION,\n",
|
|
||||||
")\n",
|
|
||||||
"cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
|
|
||||||
"\n",
|
|
||||||
"lr_scheduler = get_scheduler(\n",
|
|
||||||
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "a40468b9",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"And start training our new adapted model."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "ed051a5d",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"wandb.init(\n",
|
|
||||||
" project=\"bloom_based_cls-sst-2\",\n",
|
|
||||||
" config={\n",
|
|
||||||
" \"num_epochs\": NUM_EPOCHS,\n",
|
|
||||||
" \"batch_size\": BATCH_SIZE,\n",
|
|
||||||
" \"learning_rate\": LR,\n",
|
|
||||||
" \"weight_decay\": WEIGHT_DECAY,\n",
|
|
||||||
" \"model_name\": MODEL_NAME,\n",
|
|
||||||
" \"seed\": SEED,\n",
|
|
||||||
" \"intermediate_size\": INTERMEDIATE_SIZE,\n",
|
|
||||||
" \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n",
|
|
||||||
" \"head_layer_position\": HEAD_LAYER_POSITION,\n",
|
|
||||||
" }\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"for epoch in range(NUM_EPOCHS):\n",
|
|
||||||
" for batch in tqdm(train_dataloader):\n",
|
|
||||||
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
|
|
||||||
"\n",
|
|
||||||
" cls_model.train()\n",
|
|
||||||
" with torch.no_grad():\n",
|
|
||||||
" embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
|
|
||||||
" outputs = cls_model(embeddings_output)\n",
|
|
||||||
" loss.backward()\n",
|
|
||||||
"\n",
|
|
||||||
" cls_optimizer.step()\n",
|
|
||||||
" lr_scheduler.step()\n",
|
|
||||||
" cls_optimizer.zero_grad()\n",
|
|
||||||
"\n",
|
|
||||||
" wandb.log({\"Train Loss\": loss})\n",
|
|
||||||
"\n",
|
|
||||||
" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
|
|
||||||
" wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3.8.9 64-bit",
|
"display_name": "Python 3",
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
@ -462,13 +355,18 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.8.9 (default, Apr 13 2022, 08:48:07) \n[Clang 13.1.6 (clang-1316.0.21.2.5)]"
|
"version": "3.8.8"
|
||||||
},
|
},
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"colab": {
|
||||||
|
"provenance": [],
|
||||||
|
"gpuType": "T4"
|
||||||
|
},
|
||||||
|
"accelerator": "GPU"
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 5
|
"nbformat_minor": 5
|
||||||
|
@ -15,3 +15,4 @@ line_length = 120
|
|||||||
combine_as_imports = true
|
combine_as_imports = true
|
||||||
combine_star = true
|
combine_star = true
|
||||||
known_local_folder = ["tests", "cli"]
|
known_local_folder = ["tests", "cli"]
|
||||||
|
known_first_party = ["test_utils"]
|
||||||
|
28
setup.cfg
28
setup.cfg
@ -1,8 +1,8 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = petals
|
name = petals
|
||||||
version = 1.0alpha1
|
version = attr: petals.__version__
|
||||||
author = Petals Developers
|
author = Petals Developers
|
||||||
author_email = petals-dev@googlegroups.com
|
author_email = petals-devs@googlegroups.com
|
||||||
description = Easy way to efficiently run 100B+ language models without high-end GPUs
|
description = Easy way to efficiently run 100B+ language models without high-end GPUs
|
||||||
long_description = file: README.md
|
long_description = file: README.md
|
||||||
long_description_content_type = text/markdown
|
long_description_content_type = text/markdown
|
||||||
@ -15,9 +15,9 @@ classifiers =
|
|||||||
Intended Audience :: Science/Research
|
Intended Audience :: Science/Research
|
||||||
License :: OSI Approved :: MIT License
|
License :: OSI Approved :: MIT License
|
||||||
Programming Language :: Python :: 3
|
Programming Language :: Python :: 3
|
||||||
Programming Language :: Python :: 3.7
|
|
||||||
Programming Language :: Python :: 3.8
|
Programming Language :: Python :: 3.8
|
||||||
Programming Language :: Python :: 3.9
|
Programming Language :: Python :: 3.9
|
||||||
|
Programming Language :: Python :: 3.10
|
||||||
Topic :: Scientific/Engineering
|
Topic :: Scientific/Engineering
|
||||||
Topic :: Scientific/Engineering :: Mathematics
|
Topic :: Scientific/Engineering :: Mathematics
|
||||||
Topic :: Scientific/Engineering :: Artificial Intelligence
|
Topic :: Scientific/Engineering :: Artificial Intelligence
|
||||||
@ -29,18 +29,26 @@ classifiers =
|
|||||||
package_dir =
|
package_dir =
|
||||||
= src
|
= src
|
||||||
packages = find:
|
packages = find:
|
||||||
python_requires = >=3.7
|
python_requires = >=3.8
|
||||||
install_requires =
|
install_requires =
|
||||||
torch>=1.12
|
torch>=1.12
|
||||||
bitsandbytes==0.34.0
|
bitsandbytes==0.41.1
|
||||||
accelerate==0.15.0
|
accelerate>=0.20.3,<0.21.0
|
||||||
huggingface-hub==0.11.1
|
huggingface-hub>=0.11.1,<1.0.0
|
||||||
transformers==4.25.1
|
tokenizers>=0.13.3
|
||||||
protobuf>=3.20.3,<4.0dev
|
transformers>=4.31.0,<5.0.0
|
||||||
speedtest-cli==2.1.3
|
speedtest-cli==2.1.3
|
||||||
hivemind==1.1.3
|
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
|
||||||
|
hivemind==1.1.9
|
||||||
|
tensor_parallel==1.0.23
|
||||||
humanfriendly
|
humanfriendly
|
||||||
async-timeout>=4.0.2
|
async-timeout>=4.0.2
|
||||||
|
cpufeature>=0.2.0
|
||||||
|
packaging>=20.9
|
||||||
|
sentencepiece>=0.1.99
|
||||||
|
peft>=0.4.0
|
||||||
|
safetensors>=0.3.1
|
||||||
|
Dijkstar>=2.6.0
|
||||||
|
|
||||||
[options.extras_require]
|
[options.extras_require]
|
||||||
dev =
|
dev =
|
||||||
|
@ -1,6 +1,29 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
import transformers
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from petals.client import *
|
from petals.client import *
|
||||||
|
from petals.models import *
|
||||||
|
from petals.utils import *
|
||||||
from petals.utils.logging import initialize_logs as _initialize_logs
|
from petals.utils.logging import initialize_logs as _initialize_logs
|
||||||
|
|
||||||
__version__ = "1.0alpha1"
|
__version__ = "2.0.1.post2"
|
||||||
|
|
||||||
|
|
||||||
|
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
|
||||||
|
assert (
|
||||||
|
version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
|
||||||
|
), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def _override_bfloat16_mode_default():
|
||||||
|
if os.getenv("USE_LEGACY_BFLOAT16") is None:
|
||||||
|
hivemind.compression.base.USE_LEGACY_BFLOAT16 = False
|
||||||
|
|
||||||
|
|
||||||
_initialize_logs()
|
_initialize_logs()
|
||||||
|
_override_bfloat16_mode_default()
|
||||||
|
@ -1,59 +0,0 @@
|
|||||||
"""
|
|
||||||
Bloom intermediate layer
|
|
||||||
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
|
||||||
See commit history for authorship.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch.nn.quantized.dynamic.modules.linear
|
|
||||||
import transformers
|
|
||||||
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor
|
|
||||||
|
|
||||||
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
|
|
||||||
assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1"
|
|
||||||
|
|
||||||
|
|
||||||
class WrappedBloomBlock(BloomBlock):
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
*args,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
alibi: Optional[torch.Tensor] = None,
|
|
||||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
assert attention_mask is None
|
|
||||||
batch_size, seq_length = hidden_states.shape[:2]
|
|
||||||
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
|
|
||||||
seq_length_with_past = seq_length + past_length
|
|
||||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
|
||||||
if alibi is None:
|
|
||||||
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
|
|
||||||
attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
|
|
||||||
return super().forward(
|
|
||||||
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_attn_mask(
|
|
||||||
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
|
|
||||||
) -> torch.BoolTensor:
|
|
||||||
# create causal mask
|
|
||||||
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
|
||||||
combined_attention_mask = None
|
|
||||||
device = attention_mask.device
|
|
||||||
_, src_length = input_shape
|
|
||||||
|
|
||||||
if src_length > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
|
||||||
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
|
|
||||||
combined_attention_mask = (
|
|
||||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
return combined_attention_mask
|
|
@ -1,125 +0,0 @@
|
|||||||
"""
|
|
||||||
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
|
|
||||||
If necessary, one can rewrite this to implement a different behavior, such as:
|
|
||||||
- loading files from a local data source (e.g. S3)
|
|
||||||
- load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
|
|
||||||
- fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
import time
|
|
||||||
from typing import Optional, OrderedDict, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from transformers.modeling_utils import WEIGHTS_NAME
|
|
||||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
||||||
from transformers.utils import get_file_from_repo
|
|
||||||
|
|
||||||
from petals.bloom.block import WrappedBloomBlock
|
|
||||||
from petals.server.block_utils import get_block_size
|
|
||||||
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
CLIENT_BRANCH = "main"
|
|
||||||
BLOCK_BRANCH_PREFIX = "block_"
|
|
||||||
|
|
||||||
|
|
||||||
def load_pretrained_block(
|
|
||||||
converted_model_name_or_path: str,
|
|
||||||
block_index: int,
|
|
||||||
config: Optional[BloomConfig] = None,
|
|
||||||
torch_dtype: Union[torch.dtype, str] = "auto",
|
|
||||||
use_auth_token: Optional[str] = None,
|
|
||||||
cache_dir: Optional[str] = None,
|
|
||||||
max_disk_space: Optional[int] = None,
|
|
||||||
) -> WrappedBloomBlock:
|
|
||||||
"""Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
|
|
||||||
|
|
||||||
if config is None:
|
|
||||||
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
|
|
||||||
if cache_dir is None:
|
|
||||||
cache_dir = DEFAULT_CACHE_DIR
|
|
||||||
|
|
||||||
block = WrappedBloomBlock(config)
|
|
||||||
state_dict = _load_state_dict(
|
|
||||||
converted_model_name_or_path,
|
|
||||||
block_index,
|
|
||||||
config,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
max_disk_space=max_disk_space,
|
|
||||||
)
|
|
||||||
|
|
||||||
if torch_dtype == "auto":
|
|
||||||
with torch.no_grad():
|
|
||||||
for name, param in block.named_parameters():
|
|
||||||
assert name in state_dict, f"{name} not in state dict"
|
|
||||||
param.data = param.data.to(state_dict[name].dtype)
|
|
||||||
else:
|
|
||||||
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
|
||||||
block = block.to(dtype=torch_dtype)
|
|
||||||
|
|
||||||
report = block.load_state_dict(state_dict, strict=True)
|
|
||||||
logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
|
|
||||||
return block
|
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict(
|
|
||||||
pretrained_model_name_or_path: str,
|
|
||||||
block_index: int,
|
|
||||||
config: BloomConfig,
|
|
||||||
*,
|
|
||||||
use_auth_token: Optional[str] = None,
|
|
||||||
cache_dir: str,
|
|
||||||
max_disk_space: Optional[int] = None,
|
|
||||||
min_backoff: float = 5,
|
|
||||||
) -> OrderedDict[str, torch.Tensor]:
|
|
||||||
revision = BLOCK_BRANCH_PREFIX + str(block_index)
|
|
||||||
|
|
||||||
# First, try to find the weights locally
|
|
||||||
try:
|
|
||||||
with allow_cache_reads(cache_dir):
|
|
||||||
archive_file = get_file_from_repo(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
filename=WEIGHTS_NAME,
|
|
||||||
revision=revision,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
if archive_file is not None:
|
|
||||||
return torch.load(archive_file, map_location="cpu")
|
|
||||||
except Exception:
|
|
||||||
logger.debug(
|
|
||||||
f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# If not found, ensure that we have enough disk space to download them (maybe remove something)
|
|
||||||
for attempt_no in itertools.count():
|
|
||||||
try:
|
|
||||||
with allow_cache_writes(cache_dir):
|
|
||||||
block_size = get_block_size(config, "disk")
|
|
||||||
free_disk_space_for(
|
|
||||||
pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
|
|
||||||
)
|
|
||||||
|
|
||||||
archive_file = get_file_from_repo(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
filename=WEIGHTS_NAME,
|
|
||||||
revision=revision,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
local_files_only=False,
|
|
||||||
)
|
|
||||||
return torch.load(archive_file, map_location="cpu")
|
|
||||||
except Exception as e:
|
|
||||||
delay = min_backoff * (2**attempt_no)
|
|
||||||
logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
|
||||||
time.sleep(delay)
|
|
||||||
|
|
||||||
|
|
||||||
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
|
@ -1,72 +0,0 @@
|
|||||||
"""
|
|
||||||
PyTorch BLOOM model that implements several memory-efficient modes.
|
|
||||||
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
|
||||||
See commit history for authorship.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from hivemind import get_logger
|
|
||||||
from torch import nn
|
|
||||||
from transformers import BloomConfig
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
|
|
||||||
class LMHead(nn.Module):
|
|
||||||
"""
|
|
||||||
The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
|
|
||||||
embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
|
|
||||||
In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
|
|
||||||
super().__init__()
|
|
||||||
self.word_embeddings = word_embeddings
|
|
||||||
self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
|
|
||||||
|
|
||||||
@property
|
|
||||||
def in_features(self) -> int:
|
|
||||||
return self.word_embeddings.num_embeddings
|
|
||||||
|
|
||||||
@property
|
|
||||||
def out_features(self) -> int:
|
|
||||||
return self.word_embeddings.embedding_dim
|
|
||||||
|
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
return self.word_embeddings.weight
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bias(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
word_embeddings = self.word_embeddings.weight
|
|
||||||
|
|
||||||
# We use 'chunked_forward' only when embeddings are in half-precision on CPU.
|
|
||||||
if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
|
|
||||||
lm_logits = self.chunked_forward(hidden_states)
|
|
||||||
else:
|
|
||||||
# Switch dtype in case word_embeddings are fp16/bf16
|
|
||||||
hidden_states = hidden_states.to(word_embeddings.dtype)
|
|
||||||
lm_logits = F.linear(hidden_states, word_embeddings)
|
|
||||||
return lm_logits
|
|
||||||
|
|
||||||
def chunked_forward(self, hidden_states):
|
|
||||||
"""Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
|
|
||||||
chunk_size: provides trade-off between efficiency and extra memory consumption.
|
|
||||||
"""
|
|
||||||
assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
|
|
||||||
|
|
||||||
word_embeddings = self.word_embeddings.weight
|
|
||||||
num_embeddings = self.word_embeddings.num_embeddings
|
|
||||||
|
|
||||||
hidden_states = hidden_states.float()
|
|
||||||
output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
|
|
||||||
|
|
||||||
for i in range(0, num_embeddings, self.chunk_size):
|
|
||||||
chunk = word_embeddings[i : i + self.chunk_size].float()
|
|
||||||
output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
|
|
||||||
return output
|
|
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"apply_residual_connection_post_layernorm": false,
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"attention_softmax_in_fp32": true,
|
|
||||||
"bos_token_id": 1,
|
|
||||||
"eos_token_id": 2,
|
|
||||||
"hidden_dropout": 0.0,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"layer_norm_epsilon": 1e-05,
|
|
||||||
"masked_softmax_fusion": true,
|
|
||||||
"model_type": "bloom",
|
|
||||||
"n_embed": 14336,
|
|
||||||
"n_layer": 70,
|
|
||||||
"num_attention_heads": 112,
|
|
||||||
"pretraining_tp": 4,
|
|
||||||
"slow_but_exact": false,
|
|
||||||
"transformers_version": "4.20.0.dev0",
|
|
||||||
"use_cache": true,
|
|
||||||
"vocab_size": 250880
|
|
||||||
}
|
|
@ -1,92 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
import torch.backends.quantized
|
|
||||||
import torch.nn as nn
|
|
||||||
import transformers
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from huggingface_hub import Repository
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers.models.bloom.modeling_bloom import BloomModel
|
|
||||||
|
|
||||||
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
|
|
||||||
from petals.client import DistributedBloomConfig
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
|
|
||||||
|
|
||||||
parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
|
|
||||||
parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
|
|
||||||
parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
|
|
||||||
parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
|
|
||||||
parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
|
|
||||||
parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
|
|
||||||
parser.add_argument(
|
|
||||||
"--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
|
|
||||||
)
|
|
||||||
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
|
||||||
parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
free_ram_gb = psutil.virtual_memory().available / 2**30
|
|
||||||
if args.model == "bigscience/bloom" and free_ram_gb < 400:
|
|
||||||
logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
|
|
||||||
|
|
||||||
assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
|
|
||||||
if os.path.exists(args.output_path) and (
|
|
||||||
len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
|
|
||||||
):
|
|
||||||
raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
|
|
||||||
|
|
||||||
logger.info(f"Loading source model {args.model} (this may take a few minutes)")
|
|
||||||
config = DistributedBloomConfig.from_pretrained(
|
|
||||||
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
|
||||||
)
|
|
||||||
config.dht_prefix = args.output_repo
|
|
||||||
|
|
||||||
model = BloomModel.from_pretrained(
|
|
||||||
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
|
||||||
)
|
|
||||||
if args.resize_token_embeddings:
|
|
||||||
logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
|
|
||||||
model.resize_token_embeddings(args.resize_token_embeddings)
|
|
||||||
config.vocab_size = args.resize_token_embeddings
|
|
||||||
|
|
||||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
|
||||||
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
|
||||||
)
|
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
|
||||||
|
|
||||||
repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
|
|
||||||
repo.git_pull()
|
|
||||||
|
|
||||||
transformer_blocks = model.h
|
|
||||||
logger.info(
|
|
||||||
f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
|
|
||||||
f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
|
|
||||||
)
|
|
||||||
for i, block in enumerate(tqdm(transformer_blocks)):
|
|
||||||
repo.git_checkout(args.client_branch, create_branch_ok=True)
|
|
||||||
with repo.commit(
|
|
||||||
commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
|
|
||||||
):
|
|
||||||
torch.save(block.state_dict(), "./pytorch_model.bin")
|
|
||||||
|
|
||||||
logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
|
|
||||||
repo.git_checkout(args.client_branch, create_branch_ok=True)
|
|
||||||
with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
|
|
||||||
model.h = nn.ModuleList()
|
|
||||||
model.save_pretrained(".")
|
|
||||||
tokenizer.save_pretrained(".")
|
|
||||||
config.save_pretrained(".")
|
|
||||||
|
|
||||||
logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
|
|
@ -1,79 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
#################
|
|
||||||
# Parse options #
|
|
||||||
#################
|
|
||||||
|
|
||||||
instructions() {
|
|
||||||
echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
|
|
||||||
echo " -m: model name"
|
|
||||||
echo " -i: initial peer"
|
|
||||||
echo " -d: device" >&2
|
|
||||||
echo " -p: server identity path" >&2
|
|
||||||
echo " -b: block_ids" >&2
|
|
||||||
echo " -a: host maddrs" >&2
|
|
||||||
echo " -t: whether to run local tests" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if [ ! $# -ge 8 ]; then
|
|
||||||
instructions
|
|
||||||
fi
|
|
||||||
|
|
||||||
while getopts ":m:i:d:p:b:a:t:" option; do
|
|
||||||
case $option in
|
|
||||||
m) MODEL_NAME=${OPTARG}
|
|
||||||
;;
|
|
||||||
i) INITIAL_PEER=${OPTARG}
|
|
||||||
;;
|
|
||||||
d) DEVICE=${OPTARG}
|
|
||||||
;;
|
|
||||||
p) SERVER_ID_PATH=${OPTARG}
|
|
||||||
;;
|
|
||||||
b) BLOCK_IDS=${OPTARG}
|
|
||||||
;;
|
|
||||||
a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
|
|
||||||
;;
|
|
||||||
t) RUN_LOCAL_TESTS=true
|
|
||||||
;;
|
|
||||||
\?) instructions
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
|
|
||||||
echo "=========="
|
|
||||||
echo "= Config ="
|
|
||||||
echo "=========="
|
|
||||||
echo "Model name: ${MODEL_NAME}"
|
|
||||||
echo "Initial peer: ${INITIAL_PEER}"
|
|
||||||
echo "Device: ${DEVICE}"
|
|
||||||
echo "Server name: ${SERVER_ID_PATH}"
|
|
||||||
echo "Server address: ${HOST_MADDR}"
|
|
||||||
echo "Bloom blocks: ${BLOCK_IDS}"
|
|
||||||
|
|
||||||
|
|
||||||
###########################
|
|
||||||
# Install or activate env #
|
|
||||||
###########################
|
|
||||||
|
|
||||||
# TODO fix bug with self calling
|
|
||||||
source ~/miniconda3/etc/profile.d/conda.sh
|
|
||||||
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
|
|
||||||
conda activate bloom-demo
|
|
||||||
else
|
|
||||||
conda create -y --name bloom-demo python=3.8.12 pip
|
|
||||||
conda activate bloom-demo
|
|
||||||
|
|
||||||
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
|
|
||||||
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
|
||||||
pip install -i https://pypi.org/simple -r .
|
|
||||||
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
|
|
||||||
fi
|
|
||||||
|
|
||||||
##############
|
|
||||||
# Run server #
|
|
||||||
##############
|
|
||||||
|
|
||||||
python -m petals.cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
|
|
||||||
--block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log
|
|
@ -1,51 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from tqdm.auto import trange
|
|
||||||
from transformers import BloomConfig
|
|
||||||
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
|
|
||||||
|
|
||||||
from petals.bloom.block import BloomBlock
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
logger.warning("inference_one_block will soon be deprecated in favour of tests!")
|
|
||||||
|
|
||||||
|
|
||||||
def print_device_info(device=None):
|
|
||||||
"""Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
|
|
||||||
device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
|
||||||
logger.info(f"Using device: {device}")
|
|
||||||
|
|
||||||
# Additional Info when using cuda
|
|
||||||
if device.type == "cuda":
|
|
||||||
logger.info(torch.cuda.get_device_name(0))
|
|
||||||
logger.info(f"Memory Usage:")
|
|
||||||
logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
|
|
||||||
logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
|
|
||||||
parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
|
|
||||||
parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
|
|
||||||
parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
|
|
||||||
parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.device is None:
|
|
||||||
args.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
config = BloomConfig.from_json_file(args.config)
|
|
||||||
block = BloomBlock(config).to(args.device)
|
|
||||||
|
|
||||||
cache = None
|
|
||||||
|
|
||||||
for i in trange(args.num_steps):
|
|
||||||
dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
|
|
||||||
alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
|
|
||||||
|
|
||||||
print_device_info(args.device)
|
|
@ -1,5 +0,0 @@
|
|||||||
device=cpu
|
|
||||||
block_ids=2:3
|
|
||||||
id_path=./server.id
|
|
||||||
maddr=/ip4/127.0.0.1/tcp/30000
|
|
||||||
#
|
|
@ -1,6 +0,0 @@
|
|||||||
name=bloom-peer-0.bloom.net
|
|
||||||
device=cpu
|
|
||||||
block_ids=1:3
|
|
||||||
id_path=./server.id
|
|
||||||
maddr=/ip4/0.0.0.0/tcp/30000
|
|
||||||
#
|
|
106
src/petals/cli/run_dht.py
Normal file
106
src/petals/cli/run_dht.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
A copy of run_dht.py from hivemind with the ReachabilityProtocol added:
|
||||||
|
https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py
|
||||||
|
|
||||||
|
This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm.
|
||||||
|
|
||||||
|
This may be eventually merged to the hivemind upstream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from secrets import token_hex
|
||||||
|
|
||||||
|
from hivemind.dht import DHT, DHTNode
|
||||||
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||||
|
from hivemind.utils.networking import log_visible_maddrs
|
||||||
|
|
||||||
|
from petals.server.reachability import ReachabilityProtocol
|
||||||
|
|
||||||
|
use_hivemind_log_handler("in_root_logger")
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def report_status(dht: DHT, node: DHTNode):
|
||||||
|
logger.info(
|
||||||
|
f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) "
|
||||||
|
f"are in the local routing table "
|
||||||
|
)
|
||||||
|
logger.debug(f"Routing table contents: {node.protocol.routing_table}")
|
||||||
|
logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
|
||||||
|
logger.debug(f"Local storage contents: {node.protocol.storage}")
|
||||||
|
|
||||||
|
# Contact peers and keep the routing table healthy (remove stale PeerIDs)
|
||||||
|
await node.get(f"heartbeat_{token_hex(16)}", latest=True)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial_peers",
|
||||||
|
nargs="*",
|
||||||
|
help="Multiaddrs of the peers that will welcome you into the existing DHT. "
|
||||||
|
"Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host_maddrs",
|
||||||
|
nargs="*",
|
||||||
|
default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"],
|
||||||
|
help="Multiaddrs to listen for external connections from other DHT instances. "
|
||||||
|
"Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--announce_maddrs",
|
||||||
|
nargs="*",
|
||||||
|
help="Visible multiaddrs the host announces for external connections from other DHT instances",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_ipfs",
|
||||||
|
action="store_true",
|
||||||
|
help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
|
||||||
|
"part of the multiaddrs for the initial_peers "
|
||||||
|
"(no need to specify a particular IPv4/IPv6 host and port)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--identity_path",
|
||||||
|
help="Path to a private key file. If defined, makes the peer ID deterministic. "
|
||||||
|
"If the file does not exist, writes a new private key to this file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_relay",
|
||||||
|
action="store_false",
|
||||||
|
dest="use_relay",
|
||||||
|
help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_auto_relay",
|
||||||
|
action="store_true",
|
||||||
|
help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
dht = DHT(
|
||||||
|
start=True,
|
||||||
|
initial_peers=args.initial_peers,
|
||||||
|
host_maddrs=args.host_maddrs,
|
||||||
|
announce_maddrs=args.announce_maddrs,
|
||||||
|
use_ipfs=args.use_ipfs,
|
||||||
|
identity_path=args.identity_path,
|
||||||
|
use_relay=args.use_relay,
|
||||||
|
use_auto_relay=args.use_auto_relay,
|
||||||
|
)
|
||||||
|
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)
|
||||||
|
|
||||||
|
reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
dht.run_coroutine(report_status, return_future=False)
|
||||||
|
time.sleep(args.refresh_period)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
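Assuming the module path shown above (src/petals/cli/run_dht.py, with package_dir pointing at src), a lightweight bootstrap node could be launched roughly as follows; the multiaddr and identity path are placeholders rather than values taken from the repository:

python -m petals.cli.run_dht --host_maddrs /ip4/0.0.0.0/tcp/31337 --identity_path ./bootstrap.id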
@ -1,109 +0,0 @@
|
|||||||
# !/usr/bin/env bash
|
|
||||||
|
|
||||||
#################
|
|
||||||
# Parse options #
|
|
||||||
#################
|
|
||||||
|
|
||||||
instructions() {
|
|
||||||
echo "Usage: $0 [-n] [-c]" >&2
|
|
||||||
echo " -n: number of servers to run" >&2
|
|
||||||
echo " -c: path to the server configs" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if [ $# != 4 ]; then
|
|
||||||
instructions
|
|
||||||
fi
|
|
||||||
|
|
||||||
while getopts ":n:c:t:" option; do
|
|
||||||
case $option in
|
|
||||||
n) NUM_SERVERS=${OPTARG}
|
|
||||||
;;
|
|
||||||
c) CONFIG_PATH=${OPTARG}
|
|
||||||
;;
|
|
||||||
\?) instructions
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
|
|
||||||
###########################
|
|
||||||
# Install or activate env #
|
|
||||||
###########################
|
|
||||||
|
|
||||||
source ~/miniconda3/etc/profile.d/conda.sh
|
|
||||||
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
|
|
||||||
conda activate bloom-demo
|
|
||||||
else
|
|
||||||
conda create -y --name bloom-demo python=3.8.12 pip
|
|
||||||
conda activate bloom-demo
|
|
||||||
|
|
||||||
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
|
|
||||||
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
|
||||||
pip install -i https://pypi.org/simple -r .
|
|
||||||
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
#######################
|
|
||||||
# Create Initial peer #
|
|
||||||
#######################
|
|
||||||
|
|
||||||
hivemind-dht &> tmp.out &
|
|
||||||
sleep 5
|
|
||||||
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
|
|
||||||
echo "Initial peer: ${INITIAL_PEER}"
|
|
||||||
|
|
||||||
|
|
||||||
##############################
|
|
||||||
# Initialize the config file #
|
|
||||||
##############################
|
|
||||||
|
|
||||||
typeset -A cfg
|
|
||||||
cfg=( # set default values in config array
|
|
||||||
[device]="cpu"
|
|
||||||
[block_ids]="1:2"
|
|
||||||
[id_path]="server.id"
|
|
||||||
[maddr]="/ip4/127.0.0.1/tcp/30000"
|
|
||||||
)
|
|
||||||
|
|
||||||
###############
|
|
||||||
# Run servers #
|
|
||||||
###############
|
|
||||||
|
|
||||||
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
|
|
||||||
do
|
|
||||||
###############
|
|
||||||
# Read config #
|
|
||||||
###############
|
|
||||||
|
|
||||||
while read line
|
|
||||||
do
|
|
||||||
if echo $line | grep -F = &>/dev/null
|
|
||||||
then
|
|
||||||
varname=$(echo "$line" | cut -d '=' -f 1)
|
|
||||||
cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
|
|
||||||
fi
|
|
||||||
done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
|
|
||||||
|
|
||||||
echo "=== Server #${SERVER_ID} ==="
|
|
||||||
echo "Server ID: ${cfg[id_path]}"
|
|
||||||
echo "Device: ${cfg[device]}"
|
|
||||||
echo "Bloom block ids: ${cfg[block_ids]}"
|
|
||||||
echo "Host maddr: ${cfg[maddr]}"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
##############
|
|
||||||
# Run server #
|
|
||||||
##############
|
|
||||||
|
|
||||||
tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
|
|
||||||
done
|
|
||||||
|
|
||||||
#####################
|
|
||||||
# Kill initial peer #
|
|
||||||
#####################
|
|
||||||
|
|
||||||
sleep 10
|
|
||||||
pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
|
|
||||||
rm tmp.out
|
|
@@ -1,110 +0,0 @@
# !/usr/bin/env bash

SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"

#################
# Parse options #
#################

instructions() {
    echo "Usage: $0 [-u] [-n] [-c]" >&2
    echo " -u: username" >&2
    echo " -n: number of servers to run" >&2
    echo " -c: path to the server configs" >&2
    exit 1
}

if [ $# != 6 ]; then
    instructions
fi

while getopts ":u:n:c:" option; do
    case $option in
        u) USERNAME=${OPTARG}
            ;;
        n) NUM_SERVERS=${OPTARG}
            ;;
        c) CONFIG_PATH=${OPTARG}
            ;;
        \?) instructions
            ;;
    esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
    conda activate bloom-demo
else
    conda create -y --name bloom-demo python=3.8.12 pip
    conda activate bloom-demo

    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
    pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -i https://pypi.org/simple -r .
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &

sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
rm tmp.out
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
    [name]=""
    [device]="cpu"
    [block_ids]="1:2"
    [id_path]="server.id"
    [maddr]="/ip4/0.0.0.0/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
    ###############
    # Read config #
    ###############

    while read line
    do
        if echo $line | grep -F = &>/dev/null
        then
            varname=$(echo "$line" | cut -d '=' -f 1)
            cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
        fi
    done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

    SERVER_NAME="${USERNAME}@${cfg[name]}"
    echo "=== Server #${SERVER_ID} ==="
    echo "Server name ${SERVER_NAME}"
    echo "Server ID: ${cfg[id_path]}"
    echo "Device: ${cfg[device]}"
    echo "Bloom block ids: ${cfg[block_ids]}"
    echo "Host maddr: ${cfg[maddr]}"
    echo "================="

    ##############
    # Run server #
    ##############

    ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
done
@@ -6,10 +6,12 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
-from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.server.server import Server
+from petals.utils.convert_block import QuantType
+from petals.utils.version import validate_version
 
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 
 def main():
@@ -23,10 +25,16 @@ def main():
                        help="path or name of a pretrained model, converted with cli/convert_model.py")
     group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
 
+    parser.add_argument("--public_name", type=str, default=None, help="Public name to be reported in the leaderboard")
+
+    group = parser.add_mutually_exclusive_group(required=False)
+    group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
+    group.add_argument("--use_auth_token", action="store_true", dest="token",
+                       help="Read token saved by `huggingface-cli login")
+
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
-    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
-                                                                 "use the same name as in the converted model.")
+    parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix")
 
     parser.add_argument('--port', type=int, required=False,
                         help='Port this server listens to. '
@@ -38,25 +46,39 @@ def main():
                             'This is a simplified way to set the --announce_maddrs option (see below).'
                             'Default: server announces IPv4/IPv6 addresses of your network interfaces')
 
+    parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
+                        help="Do not look for libp2p relays to become reachable if we are behind NAT/firewall")
+
     parser.add_argument('--host_maddrs', nargs='+', required=False,
                         help='Multiaddrs to listen for external connections from other peers')
     parser.add_argument('--announce_maddrs', nargs='+', required=False,
                         help='Visible multiaddrs the host announces for external connections from other peers')
 
+    parser.add_argument('--daemon_startup_timeout', type=float, default=60,
+                        help='Timeout for the libp2p daemon connecting to initial peers')
+
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
 
     parser.add_argument('--num_handlers', type=int, default=8, required=False,
                         help='server will use this many processes to handle incoming requests')
-    parser.add_argument('--min_batch_size', type=int, default=1,
-                        help='Minimum required batch size for all operations (in total tokens)')
-    parser.add_argument('--max_batch_size', type=int, default=2048,
-                        help='The total number of tokens in the same batch will not exceed this value')
     parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
                         help='Pre-form this many subsequent batches while GPU is processing the current one')
     parser.add_argument('--sender_threads', type=int, default=1, required=False,
                         help='Use this many threads to pass results/exceptions from Runtime to Pools')
-    parser.add_argument('--inference_max_length', type=int, default=2048,
-                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
+    parser.add_argument('--inference_max_length', type=int, default=None,
+                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
+                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all operations (in total tokens)')
+    parser.add_argument('--max_batch_size', type=int, default=None,
+                        help='The total number of tokens in the same batch will not exceed this value. '
+                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
+                        help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
+    parser.add_argument('--attn_cache_tokens', type=int, default=None,
+                        help='The number of past attention key/value pairs that will be stored between inference steps. '
+                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
 
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
@@ -71,18 +93,13 @@ def main():
 
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all blocks will use this device in torch notation; default: cuda if available else cpu')
-    parser.add_argument("--torch_dtype", type=str, default="auto",
+    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--attn_cache_size', type=str, default=None,
-                        help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
-                             'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
-                             'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
-                             'The latter is the hidden size of the bigscience/bloom-petals model.')
-    parser.add_argument('--alloc_timeout', type=float, default=60,
+    parser.add_argument('--alloc_timeout', type=float, default=1,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
-    parser.add_argument('--revision', type=str, default='main',
+    parser.add_argument('--revision', type=str, default=None,
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
@@ -93,7 +110,7 @@ def main():
                             'If set to "auto" (default), the script evaluates network and compute throughput '
                             'on the first run and uses these estimates for future runs. '
                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
-    parser.add_argument('--update_period', type=float, required=False, default=150,
+    parser.add_argument('--update_period', type=float, required=False, default=120,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
@@ -105,7 +122,7 @@ def main():
                         help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")
 
     group = parser.add_mutually_exclusive_group()
-    group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
+    group.add_argument('--initial_peers', type=str, nargs='+', required=False, default=PUBLIC_INITIAL_PEERS,
                        help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
     group.add_argument('--new_swarm', action='store_true',
                        help='Start a new private swarm (i.e., do not connect to any initial peers)')
@@ -127,16 +144,23 @@ def main():
     parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
-    parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', type=str, default=None,
-                        help="Convert the loaded model into mixed-8bit quantized model. "
-                             "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
+    parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
+                        help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
+                             "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
+                             "Default: 'int8' if GPU is available, 'none' otherwise")
+    parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
+                        help=
+                        "Split each block between the specified GPUs such that each device holds a portion of every "
+                        "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
 
     parser.add_argument("--skip_reachability_check", action='store_true',
-                        help="Skip checking this server's reachability via health.petals.ml "
+                        help="Skip checking this server's reachability via health.petals.dev "
                              "when connecting to the public swarm. If you connect to a private swarm, "
                              "the check is skipped by default. Use this option only if you know what you are doing")
 
+    parser.add_argument("--adapters", nargs='*', default=(),
+                        help="List of pre-loaded LoRA adapters that can be used for inference or training")
+
     # fmt:on
     args = vars(parser.parse_args())
     args.pop("config", None)
@@ -159,19 +183,14 @@ def main():
         assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)"
         announce_maddrs = [f"/ip4/{public_ip}/tcp/{port}"]
 
+    args["startup_timeout"] = args.pop("daemon_startup_timeout")
+
     if args.pop("increase_file_limit"):
         increase_file_limit()
 
     compression_type = args.pop("compression").upper()
     compression = getattr(CompressionType, compression_type)
 
-    attn_cache_size = args.pop("attn_cache_size")
-    if attn_cache_size is not None:
-        attn_cache_size = parse_size(attn_cache_size)
-    assert isinstance(
-        attn_cache_size, (int, type(None))
-    ), "Unrecognized value for --attn_cache_size. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"
-
     max_disk_space = args.pop("max_disk_space")
     if max_disk_space is not None:
         max_disk_space = parse_size(max_disk_space)
@@ -182,9 +201,11 @@ def main():
     if args.pop("new_swarm"):
         args["initial_peers"] = []
 
-    load_in_8bit = args.pop("load_in_8bit")
-    if load_in_8bit is not None:
-        args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
+    quant_type = args.pop("quant_type")
+    if quant_type is not None:
+        args["quant_type"] = QuantType[quant_type.upper()]
+
+    validate_version()
 
     server = Server(
         **args,
@@ -192,7 +213,6 @@ def main():
         announce_maddrs=announce_maddrs,
         compression=compression,
         max_disk_space=max_disk_space,
-        attn_cache_size=attn_cache_size,
     )
     try:
         server.run()
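The updated argument parsing above exposes lowercase enum names as CLI choices for --quant_type and converts the parsed string back into an enum member. A minimal sketch of that pattern, using an illustrative stand-in enum rather than the real QuantType from petals.utils.convert_block:

# Minimal sketch of the --quant_type handling: lowercase enum names as choices,
# then map the parsed string back to the enum member. The enum below is a stand-in.
import argparse
from enum import Enum

class QuantType(Enum):  # illustrative stand-in, not Petals' definition
    NONE = 0
    INT8 = 1
    NF4 = 2

parser = argparse.ArgumentParser()
parser.add_argument("--quant_type", type=str, default=None,
                    choices=[choice.name.lower() for choice in QuantType])
args = vars(parser.parse_args(["--quant_type", "int8"]))

quant_type = args.pop("quant_type")
if quant_type is not None:
    args["quant_type"] = QuantType[quant_type.upper()]  # "int8" -> QuantType.INT8
print(args)  # {'quant_type': <QuantType.INT8: 1>}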
@@ -1,10 +1,4 @@
 from petals.client.inference_session import InferenceSession
-from petals.client.remote_model import (
-    DistributedBloomConfig,
-    DistributedBloomForCausalLM,
-    DistributedBloomForSequenceClassification,
-    DistributedBloomModel,
-)
-from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
+from petals.client.remote_sequential import RemoteSequential
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
src/petals/client/from_pretrained.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import contextlib
import json
import os
import re
import tempfile
import threading
from typing import List, Optional, Tuple, Union

import torch
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils

from petals.utils.version import get_compatible_model_repo

logger = get_logger(__name__)


class FromPretrainedMixin:
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        if torch_dtype is None:
            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
            torch_dtype = "auto"

        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
            return super().from_pretrained(
                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
            )

    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",
        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
    ).replace(
        "torch_dtype (`str` or `torch.dtype`, *optional*)",
        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
    )


_shard_config = threading.local()
_shard_config.ignored_keys = None


@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
    try:
        prev_patterns = _shard_config.ignored_keys
        _shard_config.ignored_keys = patterns
        yield
    finally:
        _shard_config.ignored_keys = prev_patterns


def patched_get_checkpoint_shard_files(
    pretrained_model_name_or_path, index_filename, *args, **kwargs
) -> Tuple[List[str], dict]:
    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""

    should_ignore_keys = _shard_config.ignored_keys is not None
    tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
    with tempdir_ctx as tempdir:
        if should_ignore_keys:
            with open(index_filename) as f:
                index = json.load(f)
            n_original_shards = len(set(index["weight_map"].values()))

            index["weight_map"] = {
                param_name: filename
                for param_name, filename in index["weight_map"].items()
                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
            }
            n_loaded_shards = len(set(index["weight_map"].values()))
            logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

            # Replace the original index with a patched JSON, where ignored keys are removed
            index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
            with open(index_filename, "w") as f:
                json.dump(index, f)

        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)


original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files
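The new loader patches get_checkpoint_shard_files() so that shards containing only ignored parameters are never downloaded: it filters the checkpoint's weight_map by the ignored-key regexes and rewrites the index. A minimal sketch of just that filtering step, with a hypothetical index and pattern:

# Minimal sketch of the index-filtering step above: drop parameters whose names match any
# ignored pattern, then count how many shard files are still needed. The index contents
# and the pattern below are hypothetical examples, not taken from an actual checkpoint.
import re

index = {
    "weight_map": {
        "word_embeddings.weight": "pytorch_model-00001-of-00003.bin",
        "h.0.self_attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
        "lm_head.weight": "pytorch_model-00003-of-00003.bin",
    }
}
ignored_patterns = [r"^h\."]  # e.g., a client that only needs the embeddings and the LM head

n_original_shards = len(set(index["weight_map"].values()))
index["weight_map"] = {
    name: filename
    for name, filename in index["weight_map"].items()
    if all(re.search(pattern, name) is None for pattern in ignored_patterns)
}
n_loaded_shards = len(set(index["weight_map"].values()))
print(f"Loading {n_loaded_shards} shards out of {n_original_shards}")  # Loading 2 shards out of 3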
@@ -2,13 +2,12 @@ from __future__ import annotations
 
 import asyncio
 import itertools
-import logging
 import time
-from typing import AsyncIterator, List, Optional
+import uuid
+from typing import AsyncIterator, List, Optional, Tuple
 
 import torch
 from hivemind import (
-    P2P,
     MSGPackSerializer,
     anext,
     deserialize_torch_tensor,
@@ -17,15 +16,15 @@ from hivemind import (
     serialize_torch_tensor,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2PHandlerError, StubBase
+from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
 
-from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
+from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
 
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 
 class _ServerInferenceSession:
@@ -37,35 +36,48 @@ class _ServerInferenceSession:
 
     def __init__(
         self,
+        config: SequenceManagerConfig,
+        span: RemoteSpanInfo,
         uid: ModuleUID,
         rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
         *,
-        timeout: float,
         max_length: int,
         **metadata,
     ):
-        self.uid, self.rpc_info = uid, rpc_info
+        self.config = config
+        self.span, self.uid, self.rpc_info = span, uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self.timeout = timeout
-        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, **metadata))
+        self.session_id = str(uuid.uuid4())
+        self.session_metadata = dict(max_length=max_length, **metadata)
        self.stepped = False
        self.closed = False
 
+        self._position = 0
+        self.history = None  # Used in case of server failures to regenerate attention caches on new servers
+        self.next_session = None
+
     @classmethod
     async def create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
+        cls,
+        config: SequenceManagerConfig,
+        p2p: P2P,
+        span: RemoteSpanInfo,
+        uid: ModuleUID,
+        rpc_info: RPCInfo,
+        **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
-            timeout,
+            config.connect_timeout,
         )
-        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
+        return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -77,55 +89,97 @@ class _ServerInferenceSession:
 
     def step(
         self,
-        new_hidden_states: torch.Tensor,
+        inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
+        *,
+        step_id: str,
     ) -> torch.Tensor:
         """
-        Inference step: send a chunk of input tesors and receive a chunk of outputs
+        Inference step: send a chunk of input tensors and receive a chunk of outputs
         :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
             if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
+        n_input_tokens = inputs.shape[1]
+        if self.history is None:
+            self.history = inputs
+        elif self.history.shape[1] == self._position:
+            self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
+        assert self.history.shape[1] == self._position + n_input_tokens, (
+            f"Broken input cache: span={self.span} shape={self.history.shape} "
+            f"position={self._position} n_input_tokens={n_input_tokens}"
+        )
+
+        if not self.stepped:
+            inputs = self.history  # Pass full inputs including prefix
+        else:
+            inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
-            assert prompts.shape[2] <= new_hidden_states.shape[1]
-            assert prompts.shape[3] == new_hidden_states.shape[2]
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
 
         if hypo_ids is None or is_dummy(hypo_ids):
             hypo_ids = DUMMY
         else:
-            assert len(hypo_ids) == len(new_hidden_states)
+            assert len(hypo_ids) == len(inputs)
             assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, prompts, hypo_ids)
+        input_tensors = (inputs, prompts, hypo_ids)
+
+        request_metadata = dict(session_id=self.session_id, step_id=step_id)
+        if not self.stepped:
+            request_metadata.update(self.session_metadata)
+        elif self.config.use_server_to_server:
+            next_servers = self._collect_next_servers()
+            if next_servers:
+                request_metadata["next_servers"] = next_servers
+
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
+                        for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
-                    metadata=self._serialized_metadata if not self.stepped else None,
+                    metadata=MSGPackSerializer.dumps(request_metadata),
                 )
             )
         )
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        assert (
+            outputs[0].shape == inputs.shape
+        ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"
+
+        self._position += n_input_tokens
+
         return outputs[0]
 
+    def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
+        next_servers = []
+        session = self.next_session
+        while session is not None and session.stepped:
+            next_servers.append(
+                (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)
+            )
+            session = session.next_session
+        return next_servers
+
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""
@@ -162,17 +216,18 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int):
+    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
         self._sequence_manager = sequence_manager
-        self._p2p = p2p
         self._closed = False
-        self._chosen_spans = []
         self._server_sessions = []
-        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
         self._max_length = max_length
         self.token_ids = []
 
+    @property
+    def num_blocks(self) -> int:
+        return len(self._sequence_manager)
+
     @property
     def position(self) -> int:
         return self._position
@@ -181,15 +236,15 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
                 span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                 metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
-                        stub,
+                        self._sequence_manager.config,
+                        self._sequence_manager.state.p2p,
+                        span,
                         span_uids,
                         rpc_info=self._sequence_manager.rpc_info,
-                        timeout=self._sequence_manager.request_timeout,
                         max_length=self._max_length,
                         **metadata,
                     )
@@ -209,7 +264,7 @@ class InferenceSession:
             logger.debug("Caught exception while closing connection to server:", exc_info=True)
 
     def __enter__(self) -> "InferenceSession":
-        assert not self._closed and not self._chosen_spans
+        assert not self._closed and not self._server_sessions
         return self
 
     def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
@@ -217,16 +272,17 @@ class InferenceSession:
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
 
-        n_blocks = len(self._sequence_manager)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
@@ -236,97 +292,76 @@ class InferenceSession:
 
         server_idx = 0
         block_idx = 0
-        recovery_until = -1  # Recovery mode is disabled until a failure happens
-        while block_idx < n_blocks:
+        while block_idx < self.num_blocks:
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
-                span = None
+                server_session = None
                 try:
-                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                        # If there is a failed server session, this code closes it
-                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
-
-                        n_prev_spans = len(self._chosen_spans)
-                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
-                        if attempt_no >= 1 and update_end > recovery_until:
-                            logger.info(
-                                f"Due to a server failure, remote attention caches "
-                                f"from block {block_idx} to {update_end} will be regenerated"
-                            )
-                        recovery_until = max(recovery_until, update_end)
-
-                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
-                        # make_sequence() could return a longer sequence
-                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
-                        updated_sessions = self._enter_server_sessions(updated_spans)
-                        logger.debug(
-                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
-                        )
-
-                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
-                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
-                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
-                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
-                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
-                            len(updated_spans) - 1
-                        )
-                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
-                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
-                            f"{len(self._server_inputs)} inputs"
-                        )
-
-                    session = self._server_sessions[server_idx]
-                    span = self._chosen_spans[server_idx]
-
-                    if self._server_inputs[server_idx] is None:
-                        self._server_inputs[server_idx] = inputs
-                    elif self._server_inputs[server_idx].shape[1] == self._position:
-                        self._server_inputs[server_idx] = torch.cat(
-                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
-                        )
-                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
-                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
-                        f"position={self._position} n_input_tokens={n_input_tokens}"
-                    )
-
-                    if not session.stepped:
-                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
-                    else:
-                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
-
-                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
-                    assert (
-                        inputs.shape == outputs.shape
-                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
-
-                    inputs = outputs
+                    if not self._server_sessions or attempt_no >= 1:
+                        self._update_sequence(server_idx, block_idx, attempt_no)
+
+                    server_session = self._server_sessions[server_idx]
+                    inputs = server_session.step(
+                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                    )
+
                     server_idx += 1
-                    block_idx = span.end
-                    self._sequence_manager.on_request_success(span.peer_id)
+                    block_idx = server_session.span.end
+                    self._sequence_manager.on_request_success(server_session.span.peer_id)
                     break
                 except Exception as e:
-                    if span is not None and not isinstance(e, P2PHandlerError):
-                        self._sequence_manager.on_request_failure(span.peer_id)
+                    self._sequence_manager.on_request_failure(
+                        server_session.span.peer_id if server_session is not None else None
+                    )
+                    if attempt_no + 1 == self._sequence_manager.config.max_retries:
+                        raise
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(
-                        f"Caught exception when running inference from block {block_idx} "
+                        f"Caught exception when running inference via {server_session.span if server_session is not None else None} "
                         f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
                     maybe_log_traceback(e)
                     time.sleep(delay)
 
         self._position += n_input_tokens
-        inputs = inputs[:, -n_input_tokens:]
-        outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+        outputs = inputs[:, -n_input_tokens:]
+        outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
         return outputs
 
+    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
+        # If there is a failed server session, this code closes it
+        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+        n_prev_spans = len(self._server_sessions)
+        update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
+        if attempt_no >= 1:
+            logger.info(
+                f"Due to a server failure, remote attention caches "
+                f"from block {block_idx} to {update_end} will be regenerated"
+            )
+
+        updated_spans = self._sequence_manager.make_sequence(
+            block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length
+        )
+        # make_sequence() could return a longer sequence
+        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
+        updated_sessions = self._enter_server_sessions(updated_spans)
+        logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers")
+
+        # If there is a failed span, this code replaces it, otherwise it just adds new ones
+        if server_idx < n_prev_spans:
+            updated_sessions[0].history = self._server_sessions[server_idx].history
+        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+
+        # Update links to the next server session for direct server-to-server communication via rpc_push()
+        for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):
+            self._server_sessions[i].next_session = self._server_sessions[i + 1]
+
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
         if not self._closed:
-            self._server_inputs.clear()
             self._exit_server_sessions(self._server_sessions)
             self._server_sessions.clear()
-            self._chosen_spans.clear()
             self._closed = True
 
     def __exit__(self, *exc_details):
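The next_session attribute introduced above forms a singly linked chain of per-server sessions; _collect_next_servers() walks this chain so the current server learns which downstream servers (peer id, session id, block range) it may push activations to directly. A minimal sketch of that traversal, with simplified stand-in classes rather than Petals' real span and session types:

# Minimal sketch of the next_session chain used by _collect_next_servers() above.
# Span and ServerSession below are simplified stand-ins for illustration.
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import uuid

@dataclass
class Span:
    peer_id: str  # stand-in for a libp2p PeerID (the real code calls .to_base58())
    start: int
    end: int

@dataclass
class ServerSession:
    span: Span
    stepped: bool = True
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    next_session: Optional["ServerSession"] = None

def collect_next_servers(session: ServerSession) -> List[Tuple[str, str, int, int]]:
    next_servers = []
    current = session.next_session
    while current is not None and current.stepped:  # stop at servers that have not started stepping yet
        next_servers.append((current.span.peer_id, current.session_id, current.span.start, current.span.end))
        current = current.next_session
    return next_servers

# Three servers holding blocks [0:8), [8:16), [16:24); the first one learns about the other two
a, b, c = ServerSession(Span("PeerA", 0, 8)), ServerSession(Span("PeerB", 8, 16)), ServerSession(Span("PeerC", 16, 24))
a.next_session, b.next_session = b, c
print(collect_next_servers(a))  # [('PeerB', '<id>', 8, 16), ('PeerC', '<id>', 16, 24)]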
src/petals/client/lm_head.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import dataclasses
import platform
from typing import Optional, Union

import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import get_logger
from torch import nn
from transformers import PretrainedConfig

logger = get_logger(__name__)


@dataclasses.dataclass
class LMHeadConfig:
    # This settings matter for running the client with dtype bfloat16 on CPU.
    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
    use_chunked_forward: Union[str, bool] = "auto"
    chunked_forward_step: int = 16384


class LMHead(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()

        if not config.tie_word_embeddings:
            self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size))
            self.weight.requires_grad = False
        else:
            self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
        self.bias = None
        self.in_features = config.hidden_size  # Similar to nn.Linear attributes
        self.out_features = config.vocab_size

        self.use_chunked_forward = config.use_chunked_forward
        if self.use_chunked_forward == "auto":
            if platform.machine() == "x86_64":
                # Import of cpufeature may crash on non-x86_64 machines
                from cpufeature import CPUFeature

                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
                # Otherwise, it's ~8x slower.
                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
            else:
                self.use_chunked_forward = True
        self.chunked_forward_step = config.chunked_forward_step
        self._bf16_warning_shown = False

    def forward(self, hidden_states):
        if (
            self.weight.dtype in [torch.float16, torch.bfloat16]
            and self.weight.device.type == "cpu"
            and self.use_chunked_forward
        ):
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(self.weight.dtype)
            lm_logits = F.linear(hidden_states, self.weight)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
        """
        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

        if not self._bf16_warning_shown:
            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                logger.warning(
                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
                    "Consider loading the model with torch_dtype='float32'"
                )
            self._bf16_warning_shown = True

        hidden_states = hidden_states.float()
        output = torch.empty(*hidden_states.shape[:-1], self.out_features)

        for i in range(0, self.out_features, self.chunked_forward_step):
            chunk = self.weight[i : i + self.chunked_forward_step].float()
            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
        return output
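chunked_forward() above computes the logits chunk by chunk: each row-chunk of the (bf16) weight matrix is cast to fp32 and multiplied separately, filling a preallocated output. The result matches a single full fp32 matmul; a minimal check with small hypothetical sizes:

# Minimal sketch of the chunked logits computation above: F.linear over fp32-cast row-chunks
# of a bf16 weight matrix, compared against one full fp32 matmul. Sizes are hypothetical.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, vocab_size, chunk_step = 64, 1000, 256
hidden_states = torch.randn(2, 5, hidden_size, dtype=torch.bfloat16)
weight = torch.randn(vocab_size, hidden_size, dtype=torch.bfloat16)

hidden_fp32 = hidden_states.float()
output = torch.empty(*hidden_fp32.shape[:-1], vocab_size)
for i in range(0, vocab_size, chunk_step):
    chunk = weight[i : i + chunk_step].float()            # cast one row-chunk to fp32
    output[..., i : i + chunk_step] = F.linear(hidden_fp32, chunk)

reference = F.linear(hidden_fp32, weight.float())          # full matmul in fp32
print(torch.allclose(output, reference))                   # True (up to float rounding)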
src/petals/client/ptune.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import dataclasses
from contextlib import contextmanager
from typing import Optional

import torch
import torch.nn as nn
from hivemind import get_logger
from transformers import PretrainedConfig

from petals.utils.misc import DUMMY

logger = get_logger(__name__)


@dataclasses.dataclass
class PTuneConfig:
    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]


class PTuneMixin:
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size,
                self.pre_seq_len,
                self.config.num_hidden_layers,
                self.config.hidden_size
                # TODO: should be num_hidden_layers - 1
            )
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)


_original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows to bypass the accelerate.init_empty_weights() context manager
    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors by empty tensors by itself
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """

    try:
        possibly_patched_register_parameter = nn.Module.register_parameter
        nn.Module.register_parameter = _original_register_parameter
        yield
    finally:
        nn.Module.register_parameter = possibly_patched_register_parameter
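In deep_ptune mode, get_prompt() above reshapes a single wide embedding into per-layer prompts of shape [num_hidden_layers, batch_size, pre_seq_len, hidden_size], the same [num_blocks, batch_size, prefix_len, hid_size] layout the inference session asserts on. A minimal sketch of that reshape and permute, with hypothetical sizes:

# Minimal sketch of the deep-prompt reshaping in get_prompt(), using hypothetical sizes.
# An embedding of width num_hidden_layers * hidden_size is reshaped and permuted into the
# [num_layers, batch_size, pre_seq_len, hidden_size] layout expected by deep prompts.
import torch
import torch.nn as nn

batch_size, pre_seq_len, num_hidden_layers, hidden_size = 2, 4, 3, 8

intermediate_prompt_embeddings = nn.Embedding(pre_seq_len, num_hidden_layers * hidden_size)
prefix_tokens = torch.arange(pre_seq_len).unsqueeze(0).expand(batch_size, -1)   # [batch, pre_seq_len]

intermediate_prompts = intermediate_prompt_embeddings(prefix_tokens)            # [batch, pre_seq_len, layers * hid]
intermediate_prompts = intermediate_prompts.view(batch_size, pre_seq_len, num_hidden_layers, hidden_size)
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])                # [layers, batch, pre_seq_len, hid]

print(intermediate_prompts.shape)  # torch.Size([3, 2, 4, 8])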
@@ -13,52 +13,53 @@ from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming

+from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.data_structures import ModuleUID, RPCInfo


async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
) -> List[torch.Tensor]:
    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
-        timeout=timeout,
+        timeout=config.request_timeout,
    )
    return [deserialize_torch_tensor(t) for t in outputs.tensors]


async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
) -> List[torch.Tensor]:
    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
-        timeout=timeout,
+        timeout=config.request_timeout,
    )
    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]


async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
) -> List[torch.Tensor]:
    parts = (
        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
        for tensor in serialized_tensors
        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
    )
-    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
-    outputs = aiter_with_timeout(outputs, timeout)
+    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout)
+    outputs = aiter_with_timeout(outputs, config.request_timeout)
    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)


async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
) -> List[torch.Tensor]:
    parts = (
        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
        for tensor in serialized_tensors
        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
    )
-    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
-    grad_inputs = aiter_with_timeout(grad_inputs, timeout)
+    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout)
+    grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout)
    return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)


@@ -67,7 +68,7 @@ async def run_remote_forward(
    stub: StubBase,
    rpc_info: RPCInfo,
    *inputs: torch.Tensor,
-    timeout: float,
+    config: SequenceManagerConfig,
    metadata: Optional[bytes] = None,
    **kwargs,
) -> Tuple[torch.Tensor, ...]:
@@ -108,8 +109,9 @@ async def run_remote_forward(

    # call RPC on remote server
    size = sum(t.element_size() * t.nelement() for t in inputs)
-    forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
-    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
+    forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
+    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
+    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])


@@ -120,7 +122,7 @@ async def run_remote_backward(
    inputs: torch.Tensor,
    grad_outputs: List[torch.Tensor],
    *extra_tensors: torch.Tensor,
-    timeout: float,
+    config: SequenceManagerConfig,
    metadata: Optional[bytes] = None,
    **kwargs,
) -> Sequence[torch.Tensor]:
@@ -150,6 +152,7 @@ async def run_remote_backward(
    )

    size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
-    backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary
-    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
+    backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
+    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
+    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
    return deserialized_grad_inputs
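# --- Editor's note: the sketch below is illustrative and is not part of the merged diff. ---
# It reduces the unary-vs-streaming decision above to a standalone helper. The "// 2" headroom
# models hivemind==1.1.5 serializing bfloat16 as float32; the limit value here is arbitrary and
# the helper name is hypothetical.
import torch

MAX_UNARY_PAYLOAD_SIZE = 2 * 1024 * 1024  # hypothetical 2 MiB cap for a single unary RPC

def pick_transport(tensors):
    size = sum(t.element_size() * t.nelement() for t in tensors)
    # leave 2x headroom because bfloat16 payloads may be serialized in float32 (twice the bytes)
    return "stream" if size > MAX_UNARY_PAYLOAD_SIZE // 2 else "unary"

print(pick_transport([torch.zeros(128, 128, dtype=torch.bfloat16)]))    # small payload -> "unary"
print(pick_transport([torch.zeros(2048, 2048, dtype=torch.bfloat16)]))  # large payload -> "stream"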
@@ -16,7 +16,7 @@ from petals.utils.generation_algorithms import (
)
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint

-logger = get_logger(__file__)
+logger = get_logger(__name__)


class RemoteGenerationMixin:
@@ -41,10 +41,11 @@ class RemoteGenerationMixin:

        return self.transformer.h.inference_session(**kwargs)

-    @torch.no_grad()
+    @torch.inference_mode()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
+        *,
        do_sample: Optional[bool] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
@@ -59,9 +60,7 @@ class RemoteGenerationMixin:
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        num_return_sequences: Optional[int] = None,
-        *,
        session: Optional[InferenceSession] = None,
-        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head.
@@ -80,19 +79,9 @@ class RemoteGenerationMixin:
        :param max_new_tokens: The maximum number of tokens to generate.
        :param decoding_algorithm: The decoding algorithm to use.
        :param provided_constraints: A list of constraints to use.
-        :param model_kwargs: Additional arguments to pass to the model.
        :param num_return_sequences: How many hypothesis from the beam will be in output.
        """

-        assert (
-            model_kwargs.get("logits_processor", None) is None
-        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
-        assert (
-            model_kwargs.get("logits_wrapper", None) is None
-        ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
-        assert (
-            model_kwargs.get("stopping_criteria", None) is None
-        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
        prefix_length = 0 if inputs is None else inputs.size(1)
        prefix_length += self.config.pre_seq_len

@@ -107,17 +96,18 @@ class RemoteGenerationMixin:
        elif max_length is None and max_new_tokens is not None:
            max_length = prefix_length + max_new_tokens

-        if num_beams > 1 and session is not None:
+        resuming_session = session is not None and session.token_ids
+        if num_beams > 1 and resuming_session:
            raise NotImplementedError(
-                "Reusing inference session in .generate() along with beam search is not supported yet"
+                "Resuming inference session in .generate() along with beam search is not supported yet"
            )

        if inputs is not None:
            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
-            if session is not None and session.token_ids:
+            if resuming_session:
                inputs = torch.cat([session.token_ids[-1], inputs], dim=1)
        else:
-            if session is not None and session.token_ids:
+            if resuming_session:
                inputs = session.token_ids[-1]
            else:
                assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
@@ -131,9 +121,7 @@ class RemoteGenerationMixin:
                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
            else:
                if top_k is not None or top_p is not None or repetition_penalty is not None:
-                    logger.warning(
-                        "You passed top_k, top_p, or repetition_penalty but did pass do_sample=True. Running greedy sampling"
-                    )
+                    raise ValueError("Passing top_k, top_p, or repetition_penalty requires passing do_sample=True")
                decoding_algorithm = GreedyAlgorithm()

        if num_beams > 1:
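# --- Editor's note: the sketch below is illustrative usage, not part of the merged diff. ---
# It shows how the now keyword-only generate() API above is meant to be called, including
# resuming an existing inference session. The tokenizer/model objects and max_length value are
# placeholders (assumptions), not something this diff pins down.
#
#   inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
#   with model.inference_session(max_length=64) as session:
#       part_1 = model.generate(inputs, max_new_tokens=8, session=session)   # fresh session
#       part_2 = model.generate(max_new_tokens=8, session=session)           # resumes the same session
#   # Sampling knobs must now be accompanied by do_sample=True, otherwise a ValueError is raised:
#   #   model.generate(inputs, do_sample=True, top_p=0.9, max_new_tokens=8)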
@@ -182,19 +170,25 @@ class RemoteGenerationMixin:
            seq_idx = outputs[0].size(1)
            hypo_ids = torch.arange(outputs[0].size(0))
            while True:
-                embs = self.transformer.word_embeddings(outputs[-1])
+                hidden_state = self.transformer.word_embeddings(outputs[-1])
                intermediate_prompts = None
                if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
-                    embs = torch.cat([prompts, embs], dim=1)
-                embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                    prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
+                    hidden_state = torch.cat([prompts, hidden_state], dim=1)
+                hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
+
+                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+
                hidden_state = self.transformer.ln_f(hidden_state)
                lm_logits = self.lm_head(hidden_state)

                for constraint in constraints:
                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
-                token_ids = torch.cat(session.token_ids, dim=1) if session.token_ids else torch.empty(batch_size, 0, dtype=torch.int64)
+                token_ids = (
+                    torch.cat(session.token_ids, dim=1)
+                    if session.token_ids
+                    else torch.empty(batch_size, 0, dtype=torch.int64)
+                )
                last_token_id, hypo_ids = decoding_algorithm(token_ids, lm_logits)

                # If some samples were padded, change only these samples
@@ -217,6 +211,8 @@ class RemoteGenerationMixin:

        outputs = torch.cat(outputs, dim=-1)

+        if resuming_session:
+            outputs = outputs[:, 1:]
        if num_beams > 1:
            pre_return_idx = [
                torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
@@ -233,7 +229,6 @@ class RemoteGenerationMixin:
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses greedy search.
@@ -251,7 +246,6 @@ class RemoteGenerationMixin:
            eos_token_id=eos_token_id,
            decoding_algorithm=GreedyAlgorithm(),
            provided_constraints=provided_constraints,
-            **model_kwargs,
        )

    def sample(
@@ -264,7 +258,6 @@ class RemoteGenerationMixin:
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
@@ -278,7 +271,6 @@ class RemoteGenerationMixin:
        :param: pad_token_id: The id of the padding token.
        :param: eos_token_id: The id of the end of sentence token.
        :param: provided_constraints: A list of constraints to use.
-        :param: model_kwargs: Additional kwargs to pass to the model.
        """

        return self.generate(
@@ -288,7 +280,6 @@ class RemoteGenerationMixin:
            eos_token_id=eos_token_id,
            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
            provided_constraints=provided_constraints,
-            **model_kwargs,
        )

    def beam_search(
@@ -299,7 +290,6 @@ class RemoteGenerationMixin:
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses beam search.
@@ -310,7 +300,6 @@ class RemoteGenerationMixin:
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end of sentence token.
        :param provided_constraints: A list of constraints to use.
-        :param: model_kwargs: Additional kwargs to pass to the model.
        """
        decoding_algorithm = BeamSearchAlgorithm(
            num_beams=num_beams,
@@ -324,7 +313,6 @@ class RemoteGenerationMixin:
            eos_token_id=eos_token_id,
            decoding_algorithm=decoding_algorithm,
            provided_constraints=provided_constraints,
-            **model_kwargs,
        )

    def beam_sample(
@@ -334,7 +322,6 @@ class RemoteGenerationMixin:
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

@@ -345,7 +332,6 @@ class RemoteGenerationMixin:
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

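# --- Editor's note: the sketch below is illustrative and is not part of the merged diff. ---
# It isolates the stricter argument validation introduced above as a standalone helper; the
# function name and return values are hypothetical.
from typing import Optional

def choose_decoding(do_sample: Optional[bool], num_beams: int, top_k=None, top_p=None, repetition_penalty=None) -> str:
    if do_sample:
        return "sampling"            # top_k / top_p / repetition_penalty are honored here
    if num_beams > 1:
        return "beam_search"
    if top_k is not None or top_p is not None or repetition_penalty is not None:
        # previously this was only a warning followed by greedy decoding; it is now an error
        raise ValueError("Passing top_k, top_p, or repetition_penalty requires passing do_sample=True")
    return "greedy"

assert choose_decoding(True, 1, top_p=0.9) == "sampling"
assert choose_decoding(None, 1) == "greedy"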
|
@ -1,264 +0,0 @@
|
|||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import hivemind
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from hivemind.utils.logging import get_logger
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
||||||
from transformers.models.bloom import (
|
|
||||||
BloomConfig,
|
|
||||||
BloomForCausalLM,
|
|
||||||
BloomForSequenceClassification,
|
|
||||||
BloomModel,
|
|
||||||
BloomPreTrainedModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
from petals.bloom.modeling_utils import LMHead
|
|
||||||
from petals.client.remote_generation import RemoteGenerationMixin
|
|
||||||
from petals.client.remote_sequential import RemoteSequential
|
|
||||||
from petals.constants import PUBLIC_INITIAL_PEERS
|
|
||||||
from petals.utils.misc import DUMMY
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomConfig(BloomConfig):
|
|
||||||
"""
|
|
||||||
A bloom config that contains information about DHT peers.
|
|
||||||
To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
|
|
||||||
"""
|
|
||||||
|
|
||||||
initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT
|
|
||||||
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
|
|
||||||
daemon_startup_timeout: int = 30
|
|
||||||
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
|
|
||||||
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
|
|
||||||
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
|
|
||||||
tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
|
|
||||||
request_timeout: int = 30 # a number of seconds for waiting result from each node
|
|
||||||
|
|
||||||
|
|
||||||
original_register_parameter = nn.Module.register_parameter
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def force_non_empty_weights():
|
|
||||||
"""
|
|
||||||
This context manager allows to bypass the accelerate.init_empty_weights() context manager
|
|
||||||
(that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
|
|
||||||
The transformers library should replace all meta tensors by empty tensors by itself
|
|
||||||
but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
|
|
||||||
|
|
||||||
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
possibly_patched_register_parameter = nn.Module.register_parameter
|
|
||||||
nn.Module.register_parameter = original_register_parameter
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
nn.Module.register_parameter = possibly_patched_register_parameter
|
|
||||||
|
|
||||||
|
|
||||||
class _LowCPUMemoryMixin:
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
|
|
||||||
if low_cpu_mem_usage is None:
|
|
||||||
low_cpu_mem_usage = True
|
|
||||||
return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
|
|
||||||
|
|
||||||
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
|
|
||||||
"low_cpu_mem_usage(`bool`, *optional*)",
|
|
||||||
"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
|
|
||||||
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
|
||||||
|
|
||||||
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
|
|
||||||
r"^(intermediate_)?prompt_embeddings\.weight$",
|
|
||||||
]
|
|
||||||
|
|
||||||
config_class = DistributedBloomConfig
|
|
||||||
|
|
||||||
def __init__(self, config: DistributedBloomConfig):
|
|
||||||
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
|
|
||||||
assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
|
|
||||||
|
|
||||||
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
|
|
||||||
super().__init__(config)
|
|
||||||
assert len(self.h) == 0
|
|
||||||
config.n_layer = n_layer
|
|
||||||
|
|
||||||
dht = (
|
|
||||||
config.dht
|
|
||||||
if config.dht is not None
|
|
||||||
else hivemind.DHT(
|
|
||||||
initial_peers=config.initial_peers,
|
|
||||||
client_mode=True,
|
|
||||||
num_workers=n_layer,
|
|
||||||
startup_timeout=config.daemon_startup_timeout,
|
|
||||||
start=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
|
|
||||||
self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)
|
|
||||||
|
|
||||||
# Forbid accumulate grads for embeddings and layernorm
|
|
||||||
self.set_requires_grad(False)
|
|
||||||
|
|
||||||
if config.tuning_mode and "ptune" in config.tuning_mode:
|
|
||||||
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
|
||||||
self.pre_seq_len = config.pre_seq_len
|
|
||||||
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
|
||||||
|
|
||||||
with force_non_empty_weights():
|
|
||||||
if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
|
|
||||||
logger.info(
|
|
||||||
"Prompt embeddings and their optimizer statistics will be kept in float32 "
|
|
||||||
"to increase ptune quality"
|
|
||||||
)
|
|
||||||
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
|
|
||||||
if config.tuning_mode == "deep_ptune":
|
|
||||||
self.intermediate_prompt_embeddings = nn.Embedding(
|
|
||||||
self.pre_seq_len,
|
|
||||||
config.num_hidden_layers * config.hidden_size,
|
|
||||||
# ^-- TODO: should be num_hidden_layers - 1
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
elif config.tuning_mode:
|
|
||||||
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
|
||||||
|
|
||||||
def set_requires_grad(self, value):
|
|
||||||
for p in self.parameters():
|
|
||||||
p.requires_grad = value
|
|
||||||
|
|
||||||
def get_prompt(self, batch_size):
|
|
||||||
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
|
|
||||||
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
|
|
||||||
prompts = self.prompt_embeddings(prefix_tokens)
|
|
||||||
|
|
||||||
if self.config.tuning_mode == "deep_ptune":
|
|
||||||
intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
|
|
||||||
intermediate_prompts = intermediate_prompts.view(
|
|
||||||
batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size # TODO: should be len(self.h) - 1
|
|
||||||
)
|
|
||||||
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
|
|
||||||
else:
|
|
||||||
intermediate_prompts = DUMMY
|
|
||||||
|
|
||||||
dtype = self.word_embeddings.weight.dtype
|
|
||||||
return prompts.to(dtype), intermediate_prompts.to(dtype)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
|
|
||||||
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if not (v is None or v is False):
|
|
||||||
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
|
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
||||||
elif input_ids is not None:
|
|
||||||
input_shape = input_ids.size()
|
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
|
||||||
|
|
||||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
||||||
batch_size = inputs_embeds.shape[0]
|
|
||||||
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
|
||||||
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
|
||||||
|
|
||||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
|
||||||
|
|
||||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
||||||
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
|
|
||||||
else:
|
|
||||||
hidden_states = self.h(hidden_states)
|
|
||||||
|
|
||||||
# Remove prefix
|
|
||||||
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
||||||
hidden_states = hidden_states[:, self.pre_seq_len :]
|
|
||||||
|
|
||||||
# Add last hidden state
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
|
||||||
hidden_states = hidden_states.view(output_shape)
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
|
||||||
last_hidden_state=hidden_states,
|
|
||||||
past_key_values=None,
|
|
||||||
hidden_states=None,
|
|
||||||
attentions=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
|
|
||||||
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
|
||||||
|
|
||||||
_keys_to_ignore_on_load_missing = (
|
|
||||||
BloomForCausalLM._keys_to_ignore_on_load_missing
|
|
||||||
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
||||||
+ [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
config_class = DistributedBloomConfig
|
|
||||||
|
|
||||||
def __init__(self, config: DistributedBloomConfig):
|
|
||||||
BloomPreTrainedModel.__init__(self, config)
|
|
||||||
self.transformer = DistributedBloomModel(config)
|
|
||||||
self.lm_head = LMHead(config, self.transformer.word_embeddings)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.transformer.word_embeddings
|
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
|
||||||
if self.config.tie_word_embeddings:
|
|
||||||
return None
|
|
||||||
return self.lm_head
|
|
||||||
|
|
||||||
def set_input_embeddings(self, new_embeddings: nn.Embedding):
|
|
||||||
assert isinstance(new_embeddings, nn.Embedding)
|
|
||||||
self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
|
|
||||||
assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
|
|
||||||
|
|
||||||
def set_output_embeddings(self, new_lm_head: nn.Linear):
|
|
||||||
with torch.no_grad():
|
|
||||||
self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
|
|
||||||
self.lm_head.bias[...] = new_lm_head.bias
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
|
|
||||||
_keys_to_ignore_on_load_missing = (
|
|
||||||
BloomForSequenceClassification._keys_to_ignore_on_load_missing
|
|
||||||
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
||||||
)
|
|
||||||
|
|
||||||
config_class = DistributedBloomConfig
|
|
||||||
|
|
||||||
def __init__(self, config: DistributedBloomConfig):
|
|
||||||
BloomPreTrainedModel.__init__(self, config)
|
|
||||||
self.num_labels = config.num_labels
|
|
||||||
|
|
||||||
self.transformer = DistributedBloomModel(config)
|
|
||||||
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
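# --- Editor's note: the sketch below is illustrative and is not part of the merged diff. ---
# It isolates the _LowCPUMemoryMixin pattern from the deleted file above: flipping the default of
# an inherited from_pretrained() keyword without re-implementing the method. `Base` stands in for
# the transformers base class; nothing here touches real checkpoints.
from typing import Optional

class Base:
    @classmethod
    def from_pretrained(cls, *args, low_cpu_mem_usage: bool = False, **kwargs):
        return {"low_cpu_mem_usage": low_cpu_mem_usage}

class LowCPUMemoryMixin:
    @classmethod
    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True  # Petals-style default: lazy "meta" init unless told otherwise
        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

class Model(LowCPUMemoryMixin, Base):
    pass

assert Model.from_pretrained() == {"low_cpu_mem_usage": True}
assert Model.from_pretrained(low_cpu_mem_usage=False) == {"low_cpu_mem_usage": False}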
@@ -3,18 +3,16 @@ from __future__ import annotations
from typing import Optional, Union

import torch
-from hivemind import DHT, P2P, get_logger
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind import DHT, get_logger
from torch import nn

-import petals.client
from petals.client.inference_session import InferenceSession
-from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
from petals.utils.misc import DUMMY

-logger = get_logger(__file__)
+logger = get_logger(__name__)


class RemoteSequential(nn.Module):
@@ -24,32 +22,28 @@ class RemoteSequential(nn.Module):

    def __init__(
        self,
-        config: petals.client.DistributedBloomConfig,
-        dht: DHT,
-        dht_prefix: Optional[str] = None,
-        p2p: Optional[P2P] = None,
+        config: SequenceManagerConfig,
+        *,
        sequence_manager: Optional[RemoteSequenceManager] = None,
+        dht: Optional[DHT] = None,
+        start_block: Optional[int] = None,
+        end_block: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        self.config = config
-        self.dht = dht
-        self.dht_prefix = dht_prefix or config.dht_prefix
-        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p

-        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
-        block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
+        assert sequence_manager is None or (
+            dht is None and start_block is None and end_block is None
+        ), "`dht`, `start_block`, and `end_block` have no effect when you provide a custom `sequence_manager`"
        if sequence_manager is None:
-            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, start=True, **kwargs)
-            self.is_subsequence = False
-        else:
-            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
-            if kwargs:
-                logger.warning(f"Parameters {kwargs} are ignored because sequence_manager is explicitly provided")
-            self.sequence_manager = sequence_manager
-            assert isinstance(sequence_manager.sequence_info.block_uids, tuple)
-            self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids
+            if start_block is None:
+                start_block = 0
+            if end_block is None:
+                end_block = self.config.num_hidden_layers
+            block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
+            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
+        self.sequence_manager = sequence_manager

    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
        assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
@@ -58,23 +52,10 @@ class RemoteSequential(nn.Module):
        return outputs

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
-        assert isinstance(ix, (int, slice))
-        if isinstance(ix, int):
-            return RemoteTransformerBlock(
-                self.config,
-                self.dht,
-                dht_prefix=self.dht_prefix,
-                p2p=self.p2p,
-                sequence_manager=self.sequence_manager[ix],
-            )
-        else:
-            return RemoteSequential(
-                self.config,
-                self.dht,
-                dht_prefix=self.dht_prefix,
-                p2p=self.p2p,
-                sequence_manager=self.sequence_manager[ix],
-            )
+        return RemoteSequential(
+            self.config,
+            sequence_manager=self.sequence_manager[ix],
+        )

    def __iter__(self):
        for block_index in range(len(self)):
@@ -84,22 +65,7 @@ class RemoteSequential(nn.Module):
        return len(self.sequence_manager)

    def inference_session(self, **kwargs) -> InferenceSession:
-        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
+        return InferenceSession(self.sequence_manager, **kwargs)

    def extra_repr(self) -> str:
        return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
-
-
-class RemoteTransformerBlock(RemoteSequential):
-    """Single transformer block hosted by swarm
-
-    This class is deprecated and kept for backward compatibility.
-    It will be removed soon in favor of using ``RemoteSequential`` directly.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        assert len(self) == 1, "Remote Block is a sequence size 1"
-
-    def extra_repr(self):
-        return f"{self.sequence_manager.block_uids[0]}"
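# --- Editor's note: the sketch below is illustrative usage, not part of the merged diff. ---
# With the new constructor, a RemoteSequential can be built from a config alone (spanning all
# blocks), from a block range, or by slicing an existing instance. `config` is assumed to be a
# model config that subclasses SequenceManagerConfig and defines dht_prefix / num_hidden_layers.
#
#   full_chain = RemoteSequential(config)                              # blocks 0 .. num_hidden_layers
#   middle = RemoteSequential(config, start_block=4, end_block=8)      # a contiguous sub-range
#   single_block = full_chain[3]                                       # slicing returns a RemoteSequential
#   with full_chain.inference_session(max_length=128) as session:
#       ...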
@@ -6,7 +6,7 @@ from hivemind import get_logger

from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState

-logger = get_logger(__file__)
+logger = get_logger(__name__)


T = TypeVar("T")
@@ -27,14 +27,14 @@ class RemoteSequenceInfo:
    block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
    spans_by_priority: List[RemoteSpanInfo]
    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
-    last_updated_time: float
+    last_updated_time: Optional[float]

    @classmethod
    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
        block_uids = tuple(block_uids)
        empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
        empty_spans = tuple([] for _ in range(len(block_uids)))
-        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf"))
+        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=None)

    def __getitem__(self, ix: slice):
        assert isinstance(ix, slice)
@@ -73,11 +73,16 @@ class RemoteSequenceInfo:
        active_spans = {}
        for block_index, info in enumerate(block_infos):
            if info is not None:
-                for peer_id, server in info.servers.items():
-                    if server.state != ServerState.ONLINE:
+                for peer_id, server_info in info.servers.items():
+                    if server_info.state != ServerState.ONLINE:
                        continue
                    if peer_id not in active_spans:
-                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                        active_spans[peer_id] = RemoteSpanInfo(
+                            peer_id=peer_id,
+                            start=block_index,
+                            end=block_index + 1,
+                            server_info=server_info,
+                        )
                    else:  # peer_id in active_spans
                        active_spans[peer_id].end = block_index + 1

@@ -91,7 +96,7 @@ class RemoteSequenceInfo:
                closed_spans.append(active_spans.pop(peer_id))
        assert not active_spans, f"spans: {active_spans}"

-        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+        closed_spans.sort(key=lambda span: span.length, reverse=True)

        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
        for span in closed_spans:
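# --- Editor's note: the sketch below is illustrative and is not part of the merged diff. ---
# It shows the span-merging idea above as a tiny standalone function: per-block server listings
# are collapsed into contiguous [start, end) spans per peer. The data structures are simplified
# stand-ins for RemoteModuleInfo / RemoteSpanInfo, and each peer is assumed to serve one
# contiguous range.
from typing import Dict, List, Set, Tuple

def compute_spans(block_servers: List[Set[str]]) -> Dict[str, Tuple[int, int]]:
    spans: Dict[str, Tuple[int, int]] = {}
    for block_index, servers in enumerate(block_servers):
        for peer_id in servers:
            if peer_id in spans and spans[peer_id][1] == block_index:
                spans[peer_id] = (spans[peer_id][0], block_index + 1)   # extend an open span
            elif peer_id not in spans:
                spans[peer_id] = (block_index, block_index + 1)         # open a new span
    return spans

# Peer "A" serves blocks 0-2, peer "B" serves blocks 1-3:
print(compute_spans([{"A"}, {"A", "B"}, {"A", "B"}, {"B"}]))  # {'A': (0, 3), 'B': (1, 4)}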
@@ -1,28 +1,71 @@
from __future__ import annotations

import asyncio
+import dataclasses
import itertools
import logging
import random
import threading
import time
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Collection, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod

+import dijkstar
+import numpy as np
from hivemind import DHT, P2P, MSGPackSerializer, PeerID
from hivemind.dht.node import Blacklist
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2PHandlerError
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger

import petals.dht_utils
from petals.client.routing.sequence_info import RemoteSequenceInfo
from petals.client.routing.spending_policy import NoSpendingPolicy
+from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
+from petals.utils.ping import PingAggregator
+from petals.utils.random import sample_up_to

-logger = get_logger(__file__)
+logger = get_logger(__name__)


+@dataclasses.dataclass
+class SequenceManagerConfig:
+    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
+    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
+    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
+
+    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
+    allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+    use_server_to_server: bool = True  # Use direct server-to-server communication
+
+    connect_timeout: float = 5  # timeout for opening a connection
+    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
+    update_period: float = 60  # refresh DHT information once in this many seconds
+
+    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
+    max_backoff: float = 60  # limit maximal sleep time between retries to this value
+    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
+
+    max_pinged: int = 3  # max servers to ping from each sequence side, per update
+    ping_timeout: float = 2  # max time to wait for pings, per update
+
+
+@dataclasses.dataclass
+class SequenceManagerState:
+    p2p: P2P = None
+    sequence_info: Optional[RemoteSequenceInfo] = None
+    rpc_info: Optional[dict] = None
+    banned_peers: Optional[Blacklist] = None
+
+    def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState:
+        return dataclasses.replace(self, sequence_info=self.sequence_info[ix])
+
+    def __len__(self) -> int:
+        return len(self.sequence_info)
+
+
class RemoteSequenceManager:
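# --- Editor's note: the sketch below is illustrative usage, not part of the merged diff. ---
# SequenceManagerConfig gathers the client-side networking knobs in one place; the values shown
# here are placeholders. In practice the distributed model config subclasses this dataclass, so
# the same fields can be passed through from_pretrained(...).
#
#   config = SequenceManagerConfig(
#       initial_peers=PUBLIC_INITIAL_PEERS,     # or the multiaddrs of a private swarm
#       dht_prefix="bigscience/bloom-petals",   # placeholder prefix
#       connect_timeout=5,                      # seconds to open a connection
#       request_timeout=3 * 60,                 # seconds per forward/backward/inference RPC
#       show_route="inference",                 # log the chosen chain of servers
#   )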
@ -33,91 +76,244 @@ class RemoteSequenceManager:
|
|||||||
Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
|
Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
|
||||||
To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
|
To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
|
||||||
|
|
||||||
:param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
|
|
||||||
:param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
|
|
||||||
:param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
|
|
||||||
:param update_period: by default, refresh DHT information once in this many seconds
|
|
||||||
:param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
|
|
||||||
:param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
|
|
||||||
:param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
|
|
||||||
:param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
|
|
||||||
:param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
|
|
||||||
:param start: start the background thread (see the note below). If false, you will need to start it manually.
|
|
||||||
:note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
|
:note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
|
||||||
running redundant sequence managers for the same set of layers.
|
running redundant sequence managers for the same set of layers.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dht: DHT,
|
config: SequenceManagerConfig,
|
||||||
block_uids: Sequence[ModuleUID],
|
block_uids: Sequence[ModuleUID],
|
||||||
p2p: P2P,
|
*,
|
||||||
update_period: float = 30,
|
dht: Optional[DHT] = None,
|
||||||
request_timeout: float = 30,
|
state: Optional[SequenceManagerState] = None,
|
||||||
min_backoff: float = 1,
|
|
||||||
ban_timeout: float = 15,
|
|
||||||
sequence_info: Optional[RemoteSequenceInfo] = None,
|
|
||||||
rpc_info: Optional[dict] = None,
|
|
||||||
banned_peers: Optional[Blacklist] = None,
|
|
||||||
*, # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
|
|
||||||
start: bool,
|
|
||||||
):
|
):
|
||||||
|
assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
|
||||||
|
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
|
||||||
assert len(block_uids) > 0, "Sequences must contain at least one block"
|
assert len(block_uids) > 0, "Sequences must contain at least one block"
|
||||||
self.dht, self.p2p = dht, p2p
|
|
||||||
self.request_timeout, self.ban_timeout, self.min_backoff = request_timeout, ban_timeout, min_backoff
|
self.config = config
|
||||||
|
if state is None:
|
||||||
|
state = SequenceManagerState()
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
if dht is None:
|
||||||
|
dht = DHT(
|
||||||
|
initial_peers=config.initial_peers,
|
||||||
|
client_mode=True,
|
||||||
|
num_workers=32,
|
||||||
|
startup_timeout=config.daemon_startup_timeout,
|
||||||
|
start=True,
|
||||||
|
)
|
||||||
|
assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
|
||||||
|
self.dht = dht
|
||||||
|
|
||||||
|
if state.p2p is None:
|
||||||
|
state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
|
||||||
|
|
||||||
self.lock_changes = threading.Lock()
|
self.lock_changes = threading.Lock()
|
||||||
self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
|
self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update))
|
||||||
|
self._thread_start_lock = threading.Lock()
|
||||||
self.policy = NoSpendingPolicy()
|
self.policy = NoSpendingPolicy()
|
||||||
self._rpc_info = rpc_info
|
|
||||||
self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
|
|
||||||
|
|
||||||
if sequence_info is None:
|
self.ping_aggregator = PingAggregator(dht)
|
||||||
self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
|
|
||||||
self.update(wait=False)
|
if state.banned_peers is None:
|
||||||
else:
|
state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
|
||||||
self.sequence_info = sequence_info
|
if state.sequence_info is None:
|
||||||
assert block_uids == sequence_info.block_uids
|
state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
|
||||||
|
|
||||||
|
if state.sequence_info.last_updated_time is not None:
|
||||||
|
assert block_uids == state.sequence_info.block_uids
|
||||||
self._thread.ready.set() # no need to await the first dht fetch
|
self._thread.ready.set() # no need to await the first dht fetch
|
||||||
|
self._need_latest_infos = True
|
||||||
|
|
||||||
if start:
|
def make_sequence(
|
||||||
self.run_in_background()
|
self,
|
||||||
|
start_index: int = 0,
|
||||||
def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
|
end_index: Optional[int] = None,
|
||||||
"""
|
*,
|
||||||
Starts the updater thread in a background. if await_ready, this method will wait until sequence manager
|
mode: str,
|
||||||
is ready to process incoming requests or for :timeout: seconds max.
|
cache_tokens_needed: Optional[int] = None,
|
||||||
"""
|
) -> List[RemoteSpanInfo]:
|
||||||
self._thread.start()
|
|
||||||
if await_ready:
|
|
||||||
self._thread.ready.wait(timeout)
|
|
||||||
|
|
||||||
def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
|
|
||||||
"""
|
"""
|
||||||
Form a sequence of remote servers that collectively serve all consecutive layers
|
Form a sequence of remote servers that collectively serve all consecutive layers
|
||||||
|
|
||||||
:param start_index: optional index of the first module in a sequence, default = the first of block_uids
|
:param start_index: optional index of the first module in a sequence, default = the first of block_uids
|
||||||
:param end_index: optional index of the last module (non-inclusive), default = after last of block uids
|
:param end_index: optional index of the last module (non-inclusive), default = after last of block uids
|
||||||
|
:param mode: one of ["max_throughput", "min_latency"]
|
||||||
"""
|
"""
|
||||||
if not self.is_alive():
|
with self._thread_start_lock:
|
||||||
logger.error("Using a sequence manager that is not running: it has either crashed or never started")
|
if not self.is_alive():
|
||||||
|
self._thread.start()
|
||||||
if not self.ready.is_set():
|
if not self.ready.is_set():
|
||||||
logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready")
|
|
||||||
self.update(wait=True) # this will await an existing update or trigger a new one (if not updating)
|
self.update(wait=True) # this will await an existing update or trigger a new one (if not updating)
|
||||||
|
|
||||||
end_index = end_index if end_index is not None else len(self)
|
end_index = end_index if end_index is not None else len(self)
|
||||||
|
|
||||||
|
if mode == "min_latency":
|
||||||
|
span_sequence = self._make_sequence_with_min_latency(
|
||||||
|
start_index, end_index, cache_tokens_needed=cache_tokens_needed
|
||||||
|
)
|
||||||
|
elif mode == "max_throughput":
|
||||||
|
span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unexpected mode {mode}")
|
||||||
|
|
||||||
|
if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"):
|
||||||
|
route_repr = " => ".join(
|
||||||
|
[f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]
|
||||||
|
)
|
||||||
|
logger.info(f"Route found: {route_repr}")
|
||||||
|
return span_sequence
|
||||||
|
|
||||||
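As a usage sketch of the routing modes handled above (not part of the diff): sequence_manager is assumed to be an already-initialized RemoteSequenceManager, and the token budget is invented for illustration.

# Hypothetical client-side usage of the two routing modes (illustration only).
# Inference prefers the latency-optimal path and accounts for attention-cache space:
inference_spans = sequence_manager.make_sequence(
    0, None, mode="min_latency", cache_tokens_needed=2048
)

# Forward/backward passes prefer longer spans to minimize the number of hops:
training_spans = sequence_manager.make_sequence(0, None, mode="max_throughput")

for span in inference_spans:
    print(f"blocks {span.start}:{span.end} via {span.peer_id}")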
|
def _make_sequence_with_min_latency(
|
||||||
|
self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]
|
||||||
|
) -> List[RemoteSpanInfo]:
|
||||||
|
if start_index == end_index:
|
||||||
|
return []
|
||||||
|
|
||||||
|
with self.lock_changes:
|
||||||
|
missing_blocks = [
|
||||||
|
block_idx
|
||||||
|
for block_idx in range(start_index, end_index)
|
||||||
|
if not self.state.sequence_info.spans_containing_block[block_idx]
|
||||||
|
]
|
||||||
|
if missing_blocks:
|
||||||
|
raise MissingBlocksError(missing_blocks)
|
||||||
|
server_infos = {
|
||||||
|
span.peer_id: span.server_info
|
||||||
|
for block_idx in range(start_index, end_index)
|
||||||
|
for span in self.state.sequence_info.spans_containing_block[block_idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)
|
||||||
|
|
||||||
|
path = dijkstar.find_path(graph, "start", "end")
|
||||||
|
logger.debug(f"Path info: {path}")
|
||||||
|
if start_index == 0 and end_index == len(self):
|
||||||
|
logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec")
|
||||||
|
|
||||||
|
span_sequence = []
|
||||||
|
for peer_id, block_idx in path.nodes[1:-1]:
|
||||||
|
if not span_sequence or span_sequence[-1].peer_id != peer_id:
|
||||||
|
span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))
|
||||||
|
else:
|
||||||
|
span_sequence[-1].end = block_idx
|
||||||
|
|
||||||
|
# Remove empty spans that can appear if we don't force the path to go to the end of each server and network delays
|
||||||
|
# don't follow the triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors
|
||||||
|
span_sequence = [span for span in span_sequence if span.length > 0]
|
||||||
|
|
||||||
|
return span_sequence
|
||||||
|
|
||||||
|
def _build_inference_graph(
|
||||||
|
self,
|
||||||
|
start_index: int,
|
||||||
|
end_index: int,
|
||||||
|
*,
|
||||||
|
cache_tokens_needed: Optional[int],
|
||||||
|
overhead_delay: float = 0.018, # Serialization overhead (empirically measured)
|
||||||
|
default_inference_rps: float = 300, # If inference RPS unknown
|
||||||
|
alloc_delay: float = 10, # If not enough cache left, we penalize the edge
|
||||||
|
) -> dijkstar.Graph:
|
||||||
|
missing_blocks = [
|
||||||
|
block_idx
|
||||||
|
for block_idx in range(start_index, end_index)
|
||||||
|
if not self.state.sequence_info.spans_containing_block[block_idx]
|
||||||
|
]
|
||||||
|
if missing_blocks:
|
||||||
|
raise MissingBlocksError(missing_blocks)
|
||||||
|
|
||||||
|
client_server_rtts = self.ping_aggregator.to_dict()
|
||||||
|
|
||||||
|
graph = dijkstar.Graph()
|
||||||
|
|
||||||
|
# Client -> server network delays
|
||||||
|
for span in self.state.sequence_info.spans_containing_block[start_index]:
|
||||||
|
delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
|
||||||
|
delay += overhead_delay
|
||||||
|
if not self._has_cache_for(span, cache_tokens_needed):
|
||||||
|
delay += alloc_delay
|
||||||
|
graph.add_edge("start", (span.peer_id, start_index), delay)
|
||||||
|
|
||||||
|
# Server -> client network delays
|
||||||
|
for span in self.state.sequence_info.spans_containing_block[end_index - 1]:
|
||||||
|
delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
|
||||||
|
graph.add_edge((span.peer_id, end_index), "end", delay)
|
||||||
|
|
||||||
|
# Server -> server network delays
|
||||||
|
for block_idx in range(start_index + 1, end_index):
|
||||||
|
for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:
|
||||||
|
if cur_span.end != block_idx:
|
||||||
|
# If we choose a server, we force the path to go to the end of it before switching to a new one
|
||||||
|
# to avoid O(N^2) graphs for N servers
|
||||||
|
continue
|
||||||
|
|
||||||
|
for next_span in self.state.sequence_info.spans_containing_block[block_idx]:
|
||||||
|
rtt = None
|
||||||
|
if cur_span.server_info.next_pings is not None:
|
||||||
|
rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())
|
||||||
|
delay = self._rtt_to_delay(rtt)
|
||||||
|
delay += overhead_delay
|
||||||
|
if not self._has_cache_for(next_span, cache_tokens_needed):
|
||||||
|
delay += alloc_delay
|
||||||
|
graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)
|
||||||
|
|
||||||
|
# Compute delays
|
||||||
|
for span in self.state.sequence_info.spans_by_priority:
|
||||||
|
for block_idx in range(max(span.start, start_index), min(span.end, end_index)):
|
||||||
|
inference_rps = span.server_info.inference_rps
|
||||||
|
if inference_rps is None:
|
||||||
|
inference_rps = default_inference_rps
|
||||||
|
graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), 1.0 / inference_rps)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
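For readers unfamiliar with the dijkstar library used above, a self-contained toy example of the same graph and shortest-path calls; the node names and delay values below are invented for illustration.

import dijkstar

# Toy routing graph: "start" -> (server, block) nodes -> "end"; edge costs are delays in seconds.
graph = dijkstar.Graph()
graph.add_edge("start", ("serverA", 0), 0.05)             # client -> serverA network delay
graph.add_edge("start", ("serverB", 0), 0.02)             # client -> serverB network delay
graph.add_edge(("serverA", 0), ("serverA", 2), 2 / 300)   # serverA computes blocks 0..1 at 300 RPS
graph.add_edge(("serverB", 0), ("serverB", 2), 2 / 100)   # serverB is slower per block
graph.add_edge(("serverA", 2), "end", 0.03)               # serverA -> client
graph.add_edge(("serverB", 2), "end", 0.04)               # serverB -> client

path = dijkstar.find_path(graph, "start", "end")
print(path.nodes)        # e.g. ['start', ('serverB', 0), ('serverB', 2), 'end']
print(path.total_cost)   # total estimated delay of the chosen route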
|
@staticmethod
|
||||||
|
def _rtt_to_delay(
|
||||||
|
rtt: float,
|
||||||
|
*,
|
||||||
|
default_delay: float = 0.15, # If network delay unknown
|
||||||
|
max_delay: float = 5, # If unreachable, we don't want to discard the edge completely
|
||||||
|
) -> float:
|
||||||
|
if rtt is None:
|
||||||
|
return default_delay
|
||||||
|
return min(rtt / 2, max_delay)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:
|
||||||
|
if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through
|
||||||
|
# this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,
|
||||||
|
# so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.
|
||||||
|
# This is okay since false positives are more costly than false negatives here.
|
||||||
|
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
|
||||||
|
|
||||||
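A quick worked example of the cache heuristic above, with invented numbers; the factor of 2 covers key and value caches, and the full span length is the pessimistic upper bound described in the comment.

# Hypothetical numbers, for illustration only.
cache_tokens_needed = 1024      # tokens the client wants to keep in the attention cache
span_length = 8                 # blocks hosted by the candidate server
cache_tokens_left = 20_000      # free cache capacity advertised by that server

# Worst case: all 8 hosted blocks are used, each needing key + value caches (factor of 2).
required = cache_tokens_needed * 2 * span_length    # 16_384
print(required <= cache_tokens_left)                # True -> no alloc_delay penalty on this edge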
|
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
|
||||||
|
client_server_rtts = self.ping_aggregator.to_dict()
|
||||||
|
|
||||||
span_sequence = []
|
span_sequence = []
|
||||||
current_index = start_index
|
current_index = start_index
|
||||||
while current_index < end_index:
|
while current_index < end_index:
|
||||||
candidate_spans = self.sequence_info.spans_containing_block[current_index]
|
candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
|
||||||
chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
|
if not candidate_spans:
|
||||||
|
raise MissingBlocksError(current_index)
|
||||||
|
|
||||||
|
# We choose longer servers to minimize the number of hops but leave some randomization
|
||||||
|
# to distribute the load. We also exclude servers known to be unreachable.
|
||||||
|
eps = 1e-6
|
||||||
|
span_weights = np.array(
|
||||||
|
[span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
|
||||||
|
|
||||||
assert chosen_span.start <= current_index < chosen_span.end
|
assert chosen_span.start <= current_index < chosen_span.end
|
||||||
span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
|
span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
|
||||||
current_index = chosen_span.end
|
current_index = chosen_span.end
|
||||||
|
|
||||||
route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence])
|
|
||||||
logger.debug(f"Route found: {route_repr}")
|
|
||||||
return span_sequence
|
return span_sequence
|
||||||
|
|
||||||
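A minimal standalone sketch of the length-weighted sampling used in the max-throughput branch above; the peer names, span lengths, and reachability flags are made up.

import numpy as np

# Candidate spans covering the current block, described only by (peer, length, reachable) here.
candidates = [("peer1", 3, True), ("peer2", 10, True), ("peer3", 7, False)]

eps = 1e-6  # unreachable servers keep a tiny non-zero weight instead of being dropped entirely
weights = np.array(
    [length if reachable else eps for _, length, reachable in candidates], dtype=np.float64
)
probs = weights / weights.sum()

chosen = np.random.choice(len(candidates), p=probs)  # longer spans are chosen more often
print(candidates[chosen][0])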
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
|
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
|
||||||
@ -125,63 +321,74 @@ class RemoteSequenceManager:
|
|||||||
assert isinstance(ix, (int, slice))
|
assert isinstance(ix, (int, slice))
|
||||||
if not isinstance(ix, slice):
|
if not isinstance(ix, slice):
|
||||||
ix = slice(int(ix), int(ix) + 1, 1)
|
ix = slice(int(ix), int(ix) + 1, 1)
|
||||||
return type(self)(
|
return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
|
||||||
self.dht,
|
|
||||||
self.block_uids[ix],
|
|
||||||
self.p2p,
|
|
||||||
update_period=self._thread.update_period,
|
|
||||||
request_timeout=self.request_timeout,
|
|
||||||
ban_timeout=self.ban_timeout,
|
|
||||||
min_backoff=self.min_backoff,
|
|
||||||
sequence_info=self.sequence_info[ix],
|
|
||||||
rpc_info=self._rpc_info,
|
|
||||||
banned_peers=self.banned_peers,
|
|
||||||
start=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(self, *, wait: bool):
|
def update(self, *, wait: bool):
|
||||||
"""Run an asynchronous update in background as soon as possible"""
|
"""Run an asynchronous update in background as soon as possible"""
|
||||||
self.ready.clear() # TODO this should be a separate event
|
self.ready.clear()
|
||||||
self._thread.trigger.set()
|
self._thread.trigger.set()
|
||||||
if wait:
|
if wait:
|
||||||
self.ready.wait()
|
self.ready.wait()
|
||||||
|
|
||||||
def _update(self):
|
def _update(self):
|
||||||
"""Perform an immediate and synchronous refresh, may take time"""
|
"""Perform an immediate and synchronous refresh, may take time"""
|
||||||
for attempt_no in itertools.count():
|
|
||||||
try:
|
|
||||||
new_block_infos = petals.dht_utils.get_remote_module_infos(
|
|
||||||
self.dht, self.block_uids, expiration_time=float("inf")
|
|
||||||
)
|
|
||||||
for block_info in new_block_infos:
|
|
||||||
if not block_info:
|
|
||||||
continue
|
|
||||||
for peer_id in tuple(block_info.servers.keys()):
|
|
||||||
if peer_id in self.banned_peers:
|
|
||||||
logger.debug(f"Ignoring banned {peer_id} for block {block_info.uid}")
|
|
||||||
block_info.servers.pop(peer_id)
|
|
||||||
|
|
||||||
with self.lock_changes:
|
new_block_infos = petals.dht_utils.get_remote_module_infos(
|
||||||
self.sequence_info.update_(new_block_infos)
|
self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
|
||||||
missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
|
)
|
||||||
if missing_blocks:
|
|
||||||
raise MissingBlocksError(f"no servers holding blocks {missing_blocks}")
|
|
||||||
self.ready.set() # if there is an active server for every block, we may begin running
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
for block_info in new_block_infos:
|
||||||
delay = self.get_retry_delay(attempt_no)
|
if not block_info:
|
||||||
logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
|
continue
|
||||||
maybe_log_traceback(e)
|
|
||||||
time.sleep(delay)
|
|
||||||
|
|
||||||
def on_request_failure(self, peer_id: PeerID):
|
# Apply whitelist, if defined
|
||||||
|
if self.config.allowed_servers is not None:
|
||||||
|
block_info.servers = {
|
||||||
|
peer_id: server_info
|
||||||
|
for peer_id, server_info in block_info.servers.items()
|
||||||
|
if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove temporarily banned peers, unless there are no peers left
|
||||||
|
valid_servers = {
|
||||||
|
peer_id: server_info
|
||||||
|
for peer_id, server_info in block_info.servers.items()
|
||||||
|
if peer_id not in self.state.banned_peers
|
||||||
|
}
|
||||||
|
if len(valid_servers) < len(block_info.servers):
|
||||||
|
if valid_servers:
|
||||||
|
logger.debug(
|
||||||
|
f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
|
||||||
|
)
|
||||||
|
block_info.servers = valid_servers
|
||||||
|
else:
|
||||||
|
# If we blacklisted all servers, the error may actually be client-caused
|
||||||
|
logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
|
||||||
|
|
||||||
|
with self.lock_changes:
|
||||||
|
self.state.sequence_info.update_(new_block_infos)
|
||||||
|
|
||||||
|
first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
|
||||||
|
middle_servers = [
|
||||||
|
span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
|
||||||
|
]
|
||||||
|
last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
|
||||||
|
|
||||||
|
pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
|
||||||
|
pinged_servers |= set(sample_up_to(middle_servers, self.config.max_pinged))
|
||||||
|
pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
|
||||||
|
self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
|
||||||
|
|
||||||
|
self.ready.set()
|
||||||
|
|
||||||
|
def on_request_failure(self, peer_id: Optional[PeerID]):
|
||||||
"""remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
|
"""remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
|
||||||
logger.info(f"Peer {peer_id} did not respond, banning it temporarily")
|
if peer_id is not None:
|
||||||
self.banned_peers.register_failure(peer_id)
|
logger.debug(f"Peer {peer_id} did not respond, banning it temporarily")
|
||||||
|
self.state.banned_peers.register_failure(peer_id)
|
||||||
with self.lock_changes:
|
with self.lock_changes:
|
||||||
should_update = False
|
should_update = False
|
||||||
for info in self.sequence_info.block_infos:
|
for info in self.state.sequence_info.block_infos:
|
||||||
info.servers.pop(peer_id, None)
|
info.servers.pop(peer_id, None)
|
||||||
if not info.servers:
|
if not info.servers:
|
||||||
should_update = True
|
should_update = True
|
||||||
@ -191,7 +398,7 @@ class RemoteSequenceManager:
|
|||||||
|
|
||||||
def on_request_success(self, peer_id: PeerID):
|
def on_request_success(self, peer_id: PeerID):
|
||||||
"""if peer has a failure streak, clear that streak"""
|
"""if peer has a failure streak, clear that streak"""
|
||||||
self.banned_peers.register_success(peer_id)
|
self.state.banned_peers.register_success(peer_id)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.block_uids)
|
return len(self.block_uids)
|
||||||
@ -206,51 +413,58 @@ class RemoteSequenceManager:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def block_uids(self):
|
def block_uids(self):
|
||||||
return self.sequence_info.block_uids
|
return self.state.sequence_info.block_uids
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rpc_info(self):
|
def rpc_info(self):
|
||||||
"""Return the rpc_info queried from one of the servers that hold the first block"""
|
"""Return the rpc_info queried from one of the servers that hold the first block"""
|
||||||
if self._rpc_info is None:
|
if self.state.rpc_info is not None:
|
||||||
for attempt_no in itertools.count():
|
return self.state.rpc_info
|
||||||
peer_id = None
|
|
||||||
try:
|
|
||||||
if not self.ready.is_set():
|
|
||||||
self.update(wait=True)
|
|
||||||
|
|
||||||
active_servers = [
|
with self._thread_start_lock:
|
||||||
peer_id
|
if not self.is_alive():
|
||||||
for peer_id, server in self.sequence_info.block_infos[0].servers.items()
|
self._thread.start()
|
||||||
if server.state == ServerState.ONLINE
|
|
||||||
]
|
|
||||||
if not active_servers:
|
|
||||||
raise MissingBlocksError("no servers holding the first block are online")
|
|
||||||
peer_id = random.choice(active_servers)
|
|
||||||
|
|
||||||
stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
|
for attempt_no in itertools.count():
|
||||||
outputs = RemoteExpertWorker.run_coroutine(
|
peer_id = None
|
||||||
stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
|
try:
|
||||||
)
|
if not self.ready.is_set():
|
||||||
self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
|
self.update(wait=True)
|
||||||
self.on_request_success(peer_id)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if peer_id is not None and not isinstance(e, P2PHandlerError):
|
|
||||||
self.on_request_failure(peer_id)
|
|
||||||
delay = self.get_retry_delay(attempt_no)
|
|
||||||
logger.warning(
|
|
||||||
f"Caught exception when gathering information from peer {peer_id} "
|
|
||||||
f"(retry in {delay:.0f} sec): {repr(e)}"
|
|
||||||
)
|
|
||||||
maybe_log_traceback(e)
|
|
||||||
time.sleep(delay)
|
|
||||||
|
|
||||||
return self._rpc_info
|
active_servers = [
|
||||||
|
peer_id
|
||||||
|
for peer_id, server in self.state.sequence_info.block_infos[0].servers.items()
|
||||||
|
if server.state == ServerState.ONLINE
|
||||||
|
]
|
||||||
|
if not active_servers:
|
||||||
|
raise MissingBlocksError(0)
|
||||||
|
peer_id = random.choice(active_servers)
|
||||||
|
|
||||||
|
stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id)
|
||||||
|
outputs = RemoteExpertWorker.run_coroutine(
|
||||||
|
stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout)
|
||||||
|
)
|
||||||
|
self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
|
||||||
|
self.on_request_success(peer_id)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
self.on_request_failure(peer_id)
|
||||||
|
if attempt_no + 1 == self.config.max_retries:
|
||||||
|
raise
|
||||||
|
delay = self.get_retry_delay(attempt_no)
|
||||||
|
logger.warning(
|
||||||
|
f"Caught exception when gathering information from peer {peer_id} "
|
||||||
|
f"(retry in {delay:.0f} sec): {repr(e)}"
|
||||||
|
)
|
||||||
|
maybe_log_traceback(e)
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
return self.state.rpc_info
|
||||||
|
|
||||||
def get_retry_delay(self, attempt_no: int) -> float:
|
def get_retry_delay(self, attempt_no: int) -> float:
|
||||||
if attempt_no == 0:
|
if attempt_no == 0:
|
||||||
return 0
|
return 0
|
||||||
return self.min_backoff * 2 ** (attempt_no - 1)
|
return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
|
||||||
|
|
||||||
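For reference, the capped exponential backoff above yields delays like the following; the min_backoff and max_backoff values here are illustrative, not the config defaults.

# Illustrative values only; the real ones come from the client config.
min_backoff, max_backoff = 1.0, 60.0

def retry_delay(attempt_no: int) -> float:
    if attempt_no == 0:
        return 0
    return min(min_backoff * 2 ** (attempt_no - 1), max_backoff)

print([retry_delay(i) for i in range(8)])
# [0, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 60.0]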
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
|
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@ -259,7 +473,7 @@ class RemoteSequenceManager:
|
|||||||
:param kwargs: additional request context, such as remote peer ID
|
:param kwargs: additional request context, such as remote peer ID
|
||||||
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
|
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
|
||||||
"""
|
"""
|
||||||
return dict(points=self.policy.get_points(protocol, *args, **kwargs))
|
return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
self._thread.shutdown()
|
self._thread.shutdown()
|
||||||
@ -271,18 +485,11 @@ class _SequenceManagerUpdateThread(threading.Thread):
|
|||||||
self.ref_update_manager = ref_update_manager
|
self.ref_update_manager = ref_update_manager
|
||||||
self.ready = threading.Event()
|
self.ready = threading.Event()
|
||||||
self.trigger = threading.Event()
|
self.trigger = threading.Event()
|
||||||
self.last_update_time = -float("inf")
|
|
||||||
self.update_period = update_period
|
self.update_period = update_period
|
||||||
self.should_shutdown = False
|
self.should_shutdown = False
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
while not self.should_shutdown:
|
while not self.should_shutdown:
|
||||||
self.trigger.wait(max(0.0, min(self.update_period, time.perf_counter() - self.last_update_time)))
|
|
||||||
|
|
||||||
if self.should_shutdown:
|
|
||||||
logger.debug(f"{self.__class__.__name__} is shutting down")
|
|
||||||
break
|
|
||||||
|
|
||||||
update_manager = self.ref_update_manager()
|
update_manager = self.ref_update_manager()
|
||||||
if update_manager is None:
|
if update_manager is None:
|
||||||
logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists")
|
logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists")
|
||||||
@ -296,16 +503,18 @@ class _SequenceManagerUpdateThread(threading.Thread):
|
|||||||
finally:
|
finally:
|
||||||
del update_manager
|
del update_manager
|
||||||
|
|
||||||
|
self.trigger.wait(self.update_period)
|
||||||
|
|
||||||
logger.debug(f"{self.__class__.__name__} thread exited")
|
logger.debug(f"{self.__class__.__name__} thread exited")
|
||||||
|
|
||||||
def shutdown(self, timeout: Optional[float] = None):
|
def shutdown(self, timeout: Optional[float] = None):
|
||||||
self.should_shutdown = True
|
self.should_shutdown = True
|
||||||
self.trigger.set()
|
self.trigger.set()
|
||||||
self.join(timeout)
|
if self.is_alive():
|
||||||
|
self.join(timeout)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.is_alive():
|
self.shutdown()
|
||||||
self.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_log_traceback(exc: Exception):
|
def maybe_log_traceback(exc: Exception):
|
||||||
@ -313,6 +522,11 @@ def maybe_log_traceback(exc: Exception):
|
|||||||
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
|
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
class MissingBlocksError(Exception):
|
class MissingBlocksError(RuntimeError):
|
||||||
def __repr__(self):
|
def __init__(self, block_indices: Union[int, Sequence[int]]):
|
||||||
return self.args[0]
|
super().__init__(
|
||||||
|
f"No servers holding blocks {block_indices} are online. "
|
||||||
|
f"You can check the public swarm's state at https://health.petals.dev "
|
||||||
|
f"If there are not enough servers, please connect your GPU: "
|
||||||
|
f"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity "
|
||||||
|
)
|
||||||
|
@ -3,14 +3,12 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import List, Optional, Sequence, Tuple
|
from typing import List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from hivemind import MSGPackSerializer
|
from hivemind import MSGPackSerializer
|
||||||
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||||
from hivemind.p2p import P2PHandlerError
|
|
||||||
from hivemind.utils.logging import get_logger
|
from hivemind.utils.logging import get_logger
|
||||||
|
|
||||||
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
|
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
|
||||||
@ -19,7 +17,7 @@ from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
|
|||||||
from petals.server.handler import TransformerConnectionHandler
|
from petals.server.handler import TransformerConnectionHandler
|
||||||
from petals.utils.misc import DUMMY, is_dummy
|
from petals.utils.misc import DUMMY, is_dummy
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
MAX_TOKENS_IN_BATCH = 1024
|
MAX_TOKENS_IN_BATCH = 1024
|
||||||
|
|
||||||
@ -61,14 +59,14 @@ async def sequential_forward(
|
|||||||
span = None
|
span = None
|
||||||
try:
|
try:
|
||||||
if not sequences or attempt_no >= 1:
|
if not sequences or attempt_no >= 1:
|
||||||
sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
|
sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="max_throughput"))
|
||||||
# make_sequence() could return a longer sequence
|
# make_sequence() could return a longer sequence
|
||||||
sequences[-1].end = min(sequences[-1].end, end_index)
|
sequences[-1].end = min(sequences[-1].end, end_index)
|
||||||
logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")
|
logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")
|
||||||
|
|
||||||
span = sequences.popleft()
|
span = sequences.popleft()
|
||||||
|
|
||||||
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
|
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
|
||||||
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
|
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
|
||||||
|
|
||||||
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
||||||
@ -78,7 +76,7 @@ async def sequential_forward(
|
|||||||
stub,
|
stub,
|
||||||
sequence_manager.rpc_info,
|
sequence_manager.rpc_info,
|
||||||
*inputs_and_prompts,
|
*inputs_and_prompts,
|
||||||
timeout=sequence_manager.request_timeout,
|
config=sequence_manager.config,
|
||||||
metadata=MSGPackSerializer.dumps(metadata),
|
metadata=MSGPackSerializer.dumps(metadata),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,12 +92,12 @@ async def sequential_forward(
|
|||||||
sequence_manager.on_request_success(span.peer_id)
|
sequence_manager.on_request_success(span.peer_id)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if span is not None and not isinstance(e, P2PHandlerError):
|
sequence_manager.on_request_failure(span.peer_id if span is not None else None)
|
||||||
sequence_manager.on_request_failure(span.peer_id)
|
if attempt_no + 1 == sequence_manager.config.max_retries:
|
||||||
|
raise
|
||||||
delay = sequence_manager.get_retry_delay(attempt_no)
|
delay = sequence_manager.get_retry_delay(attempt_no)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Caught exception when running forward from block {block_idx} "
|
f"Caught exception when running forward via {span} (retry in {delay:.0f} sec): {repr(e)}"
|
||||||
f"(retry in {delay:.0f} sec): {repr(e)}"
|
|
||||||
)
|
)
|
||||||
maybe_log_traceback(e)
|
maybe_log_traceback(e)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
@ -152,7 +150,7 @@ async def sequential_backward(
|
|||||||
span = forward_sequences.pop()
|
span = forward_sequences.pop()
|
||||||
|
|
||||||
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
||||||
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
|
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
|
||||||
metadata = sequence_manager.get_request_metadata(
|
metadata = sequence_manager.get_request_metadata(
|
||||||
"rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
|
"rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
|
||||||
)
|
)
|
||||||
@ -163,7 +161,7 @@ async def sequential_backward(
|
|||||||
inputs,
|
inputs,
|
||||||
grad_outputs,
|
grad_outputs,
|
||||||
prompts[span.start : span.end],
|
prompts[span.start : span.end],
|
||||||
timeout=sequence_manager.request_timeout,
|
config=sequence_manager.config,
|
||||||
metadata=MSGPackSerializer.dumps(metadata),
|
metadata=MSGPackSerializer.dumps(metadata),
|
||||||
)
|
)
|
||||||
grad_outputs = [grad_outputs]
|
grad_outputs = [grad_outputs]
|
||||||
@ -171,12 +169,12 @@ async def sequential_backward(
|
|||||||
sequence_manager.on_request_success(span.peer_id)
|
sequence_manager.on_request_success(span.peer_id)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if span is not None and not isinstance(e, P2PHandlerError):
|
sequence_manager.on_request_failure(span.peer_id if span is not None else None)
|
||||||
sequence_manager.on_request_failure(span.peer_id)
|
if attempt_no + 1 == sequence_manager.config.max_retries:
|
||||||
|
raise
|
||||||
delay = sequence_manager.get_retry_delay(attempt_no)
|
delay = sequence_manager.get_retry_delay(attempt_no)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Caught exception when running backward between blocks {span.start}-{span.end} "
|
f"Caught exception when running backward via {span} (retry in {delay:.0f} sec): {repr(e)}"
|
||||||
f"(retry in {delay:.0f} sec): {repr(e)}"
|
|
||||||
)
|
)
|
||||||
maybe_log_traceback(e)
|
maybe_log_traceback(e)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
@ -1,6 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
PUBLIC_INITIAL_PEERS = [
|
PUBLIC_INITIAL_PEERS = [
|
||||||
"/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
# IPv4 DNS addresses
|
||||||
"/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
"/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||||
"/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
"/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||||
"/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
# IPv6 DNS addresses
|
||||||
|
"/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||||
|
"/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||||
|
# Reserved IPs
|
||||||
|
"/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
|
||||||
|
"/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# The reachability API is currently used only when connecting to the public swarm
|
||||||
|
REACHABILITY_API_URL = "https://health.petals.dev"
|
||||||
|
|
||||||
|
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
||||||
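A short usage note for the new constants, assuming they live in petals.constants as in the upstream layout.

import torch
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

assert DTYPE_MAP["bfloat16"] is torch.bfloat16
assert DTYPE_MAP["auto"] == "auto"
print(len(PUBLIC_INITIAL_PEERS))  # bootstrap multiaddrs for the public swarm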
|
@ -1,8 +1,12 @@
|
|||||||
from dataclasses import dataclass
|
import dataclasses
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import pydantic
|
||||||
from hivemind import PeerID
|
from hivemind import PeerID
|
||||||
|
from hivemind.moe.expert_uid import ExpertUID
|
||||||
|
|
||||||
|
from petals.server.memory_cache import Handle
|
||||||
|
|
||||||
ModuleUID = str
|
ModuleUID = str
|
||||||
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
|
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
|
||||||
@ -15,13 +19,42 @@ class ServerState(Enum):
|
|||||||
ONLINE = 2
|
ONLINE = 2
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pydantic.dataclasses.dataclass
|
||||||
class ServerInfo:
|
class ServerInfo:
|
||||||
state: ServerState
|
state: ServerState
|
||||||
throughput: float
|
throughput: RPS
|
||||||
|
|
||||||
|
public_name: Optional[str] = None
|
||||||
|
version: Optional[str] = None
|
||||||
|
|
||||||
|
network_rps: Optional[RPS] = None
|
||||||
|
forward_rps: Optional[RPS] = None
|
||||||
|
inference_rps: Optional[RPS] = None
|
||||||
|
|
||||||
|
adapters: Sequence[str] = ()
|
||||||
|
torch_dtype: Optional[str] = None
|
||||||
|
quant_type: Optional[str] = None
|
||||||
|
using_relay: Optional[bool] = None
|
||||||
|
cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
|
||||||
|
next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None
|
||||||
|
|
||||||
|
def to_tuple(self) -> Tuple[int, float, dict]:
|
||||||
|
extra_info = dataclasses.asdict(self)
|
||||||
|
del extra_info["state"], extra_info["throughput"]
|
||||||
|
return (self.state.value, self.throughput, extra_info)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tuple(cls, source: tuple):
|
||||||
|
state, throughput = source[:2]
|
||||||
|
extra_info = source[2] if len(source) > 2 else {}
|
||||||
|
# pydantic will validate existing fields and ignore extra ones
|
||||||
|
return cls(state=ServerState(state), throughput=throughput, **extra_info)
|
||||||
|
|
||||||
|
|
||||||
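A small round-trip sketch of the tuple wire format defined by to_tuple/from_tuple above, assuming ServerInfo and ServerState are importable from petals.data_structures as shown in this diff.

from petals.data_structures import ServerInfo, ServerState

info = ServerInfo(
    state=ServerState.ONLINE, throughput=120.0, public_name="demo-server", version="2.0.0"
)

packed = info.to_tuple()            # (2, 120.0, {"public_name": "demo-server", "version": "2.0.0", ...})
restored = ServerInfo.from_tuple(packed)

assert restored.state == ServerState.ONLINE
assert restored.throughput == 120.0
assert restored.public_name == "demo-server"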
@dataclass
|
@dataclasses.dataclass
|
||||||
class RemoteModuleInfo:
|
class RemoteModuleInfo:
|
||||||
"""A remote module that is served by one or more servers"""
|
"""A remote module that is served by one or more servers"""
|
||||||
|
|
||||||
@ -29,13 +62,26 @@ class RemoteModuleInfo:
|
|||||||
servers: Dict[PeerID, ServerInfo]
|
servers: Dict[PeerID, ServerInfo]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class RemoteSpanInfo:
|
class RemoteSpanInfo:
|
||||||
"""A chain of remote blocks served by one specific remote peer"""
|
"""A chain of remote blocks served by one specific remote peer"""
|
||||||
|
|
||||||
|
peer_id: PeerID
|
||||||
start: int
|
start: int
|
||||||
end: int
|
end: int
|
||||||
peer_id: PeerID
|
server_info: ServerInfo
|
||||||
|
|
||||||
|
@property
|
||||||
|
def length(self):
|
||||||
|
return self.end - self.start
|
||||||
|
|
||||||
|
|
||||||
RPCInfo = Dict[str, Any]
|
RPCInfo = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class InferenceMetadata:
|
||||||
|
uid: ExpertUID
|
||||||
|
prefix_length: int
|
||||||
|
cache_handles: Tuple[Handle, ...]
|
||||||
|
active_adapter: Optional[str]
|
||||||
|
@ -8,22 +8,19 @@ from functools import partial
|
|||||||
from typing import Dict, List, Optional, Sequence, Union
|
from typing import Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from hivemind.dht import DHT, DHTNode, DHTValue
|
from hivemind.dht import DHT, DHTNode, DHTValue
|
||||||
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
||||||
from hivemind.p2p import PeerID
|
from hivemind.p2p import PeerID
|
||||||
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
|
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
|
||||||
|
|
||||||
import petals.client
|
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
|
||||||
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def declare_active_modules(
|
def declare_active_modules(
|
||||||
dht: DHT,
|
dht: DHT,
|
||||||
uids: Sequence[ModuleUID],
|
uids: Sequence[ModuleUID],
|
||||||
|
server_info: ServerInfo,
|
||||||
expiration_time: DHTExpiration,
|
expiration_time: DHTExpiration,
|
||||||
state: ServerState,
|
|
||||||
throughput: float,
|
|
||||||
wait: bool = True,
|
wait: bool = True,
|
||||||
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
|
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
|
||||||
"""
|
"""
|
||||||
@ -41,14 +38,9 @@ def declare_active_modules(
|
|||||||
uids = list(uids)
|
uids = list(uids)
|
||||||
for uid in uids:
|
for uid in uids:
|
||||||
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
|
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
|
||||||
|
|
||||||
return dht.run_coroutine(
|
return dht.run_coroutine(
|
||||||
partial(
|
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
|
||||||
_declare_active_modules,
|
|
||||||
uids=uids,
|
|
||||||
expiration_time=expiration_time,
|
|
||||||
state=state,
|
|
||||||
throughput=throughput,
|
|
||||||
),
|
|
||||||
return_future=not wait,
|
return_future=not wait,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -57,97 +49,52 @@ async def _declare_active_modules(
|
|||||||
dht: DHT,
|
dht: DHT,
|
||||||
node: DHTNode,
|
node: DHTNode,
|
||||||
uids: List[ModuleUID],
|
uids: List[ModuleUID],
|
||||||
|
server_info: ServerInfo,
|
||||||
expiration_time: DHTExpiration,
|
expiration_time: DHTExpiration,
|
||||||
state: ServerState,
|
|
||||||
throughput: float,
|
|
||||||
) -> Dict[ModuleUID, bool]:
|
) -> Dict[ModuleUID, bool]:
|
||||||
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
||||||
return await node.store_many(
|
return await node.store_many(
|
||||||
keys=uids,
|
keys=uids,
|
||||||
subkeys=[dht.peer_id.to_base58()] * len(uids),
|
subkeys=[dht.peer_id.to_base58()] * len(uids),
|
||||||
values=[(state.value, throughput)] * len(uids),
|
values=[server_info.to_tuple()] * len(uids),
|
||||||
expiration_time=expiration_time,
|
expiration_time=expiration_time,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_remote_sequence(
|
|
||||||
dht: DHT,
|
|
||||||
start: int,
|
|
||||||
stop: int,
|
|
||||||
config: petals.client.DistributedBloomConfig,
|
|
||||||
dht_prefix: Optional[str] = None,
|
|
||||||
return_future: bool = False,
|
|
||||||
) -> Union[petals.client.RemoteSequential, MPFuture]:
|
|
||||||
return RemoteExpertWorker.run_coroutine(
|
|
||||||
_get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_remote_sequence(
|
|
||||||
dht: DHT,
|
|
||||||
start: int,
|
|
||||||
stop: int,
|
|
||||||
config: petals.client.DistributedBloomConfig,
|
|
||||||
dht_prefix: Optional[str] = None,
|
|
||||||
) -> petals.client.RemoteSequential:
|
|
||||||
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
|
|
||||||
p2p = await dht.replicate_p2p()
|
|
||||||
manager = petals.client.RemoteSequenceManager(dht, uids, p2p, start=True)
|
|
||||||
return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
|
|
||||||
|
|
||||||
|
|
||||||
def get_remote_module(
|
|
||||||
dht: DHT,
|
|
||||||
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
||||||
config: petals.client.DistributedBloomConfig,
|
|
||||||
dht_prefix: Optional[str] = None,
|
|
||||||
return_future: bool = False,
|
|
||||||
) -> Union[Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]], MPFuture]:
|
|
||||||
"""
|
|
||||||
:param uid_or_uids: find one or more modules with these ids from across the DHT
|
|
||||||
:param config: model config, usually taken by .from_pretrained(MODEL_NAME)
|
|
||||||
:param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
|
|
||||||
:returns: a list of [RemoteTransformerBlock]
|
|
||||||
"""
|
|
||||||
return RemoteExpertWorker.run_coroutine(
|
|
||||||
_get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_remote_module(
|
|
||||||
dht: DHT,
|
|
||||||
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
||||||
config: petals.client.DistributedBloomConfig,
|
|
||||||
dht_prefix: Optional[str] = None,
|
|
||||||
) -> Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]]:
|
|
||||||
single_uid = isinstance(uid_or_uids, ModuleUID)
|
|
||||||
uids = [uid_or_uids] if single_uid else uid_or_uids
|
|
||||||
p2p = await dht.replicate_p2p()
|
|
||||||
managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p, start=True) for uid in uids)
|
|
||||||
modules = [
|
|
||||||
petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
|
|
||||||
for m in managers
|
|
||||||
]
|
|
||||||
return modules[0] if single_uid else modules
|
|
||||||
|
|
||||||
|
|
||||||
def get_remote_module_infos(
|
def get_remote_module_infos(
|
||||||
dht: DHT, uid_or_uids: Union[ModuleUID, Sequence[ModuleUID]], expiration_time: Optional[DHTExpiration] = None
|
dht: DHT,
|
||||||
) -> List[Optional[RemoteModuleInfo]]:
|
uids: Sequence[ModuleUID],
|
||||||
single_uid = isinstance(uid_or_uids, ModuleUID)
|
expiration_time: Optional[DHTExpiration] = None,
|
||||||
uids = [uid_or_uids] if single_uid else uid_or_uids
|
active_adapter: Optional[str] = None,
|
||||||
infos = dht.run_coroutine(
|
*,
|
||||||
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time),
|
latest: bool = False,
|
||||||
return_future=False,
|
return_future: bool = False,
|
||||||
|
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
|
||||||
|
return dht.run_coroutine(
|
||||||
|
partial(
|
||||||
|
_get_remote_module_infos,
|
||||||
|
uids=uids,
|
||||||
|
active_adapter=active_adapter,
|
||||||
|
expiration_time=expiration_time,
|
||||||
|
latest=latest,
|
||||||
|
),
|
||||||
|
return_future=return_future,
|
||||||
)
|
)
|
||||||
return infos[0] if single_uid else infos
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_remote_module_infos(
|
async def _get_remote_module_infos(
|
||||||
dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
|
dht: DHT,
|
||||||
|
node: DHTNode,
|
||||||
|
uids: List[ModuleUID],
|
||||||
|
active_adapter: Optional[str],
|
||||||
|
expiration_time: Optional[DHTExpiration],
|
||||||
|
latest: bool,
|
||||||
) -> List[Optional[RemoteModuleInfo]]:
|
) -> List[Optional[RemoteModuleInfo]]:
|
||||||
if expiration_time is None:
|
if latest:
|
||||||
|
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
|
||||||
|
expiration_time = math.inf
|
||||||
|
elif expiration_time is None:
|
||||||
expiration_time = get_dht_time()
|
expiration_time = get_dht_time()
|
||||||
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
||||||
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
|
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
|
||||||
@ -157,23 +104,21 @@ async def _get_remote_module_infos(
|
|||||||
metadata = found[uid]
|
metadata = found[uid]
|
||||||
if metadata is None or not isinstance(metadata.value, dict):
|
if metadata is None or not isinstance(metadata.value, dict):
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
logger.error(f"Incorrect metadata for {uid}: {metadata}")
|
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
|
||||||
continue
|
continue
|
||||||
servers = {}
|
servers = {}
|
||||||
for peer_id, server_info in metadata.value.items():
|
for peer_id, server_info in metadata.value.items():
|
||||||
try:
|
try:
|
||||||
peer_id = PeerID.from_base58(peer_id)
|
peer_id = PeerID.from_base58(peer_id)
|
||||||
state, throughput = server_info.value
|
server_info = ServerInfo.from_tuple(server_info.value)
|
||||||
if not (
|
|
||||||
isinstance(state, int)
|
if active_adapter and active_adapter not in server_info.adapters:
|
||||||
and isinstance(throughput, float)
|
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
|
||||||
and math.isfinite(throughput)
|
continue
|
||||||
and throughput >= 0.0
|
|
||||||
):
|
servers[peer_id] = server_info
|
||||||
raise ValueError(f"Invalid server info: {server_info}")
|
|
||||||
servers[peer_id] = ServerInfo(ServerState(state), throughput)
|
|
||||||
except (TypeError, ValueError) as e:
|
except (TypeError, ValueError) as e:
|
||||||
logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
|
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
|
||||||
if servers:
|
if servers:
|
||||||
modules[i] = RemoteModuleInfo(uid, servers)
|
modules[i] = RemoteModuleInfo(uid, servers)
|
||||||
return modules
|
return modules
|
||||||
|
2
src/petals/models/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from petals.models.bloom import *
|
||||||
|
from petals.models.llama import *
|
15
src/petals/models/bloom/__init__.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from petals.models.bloom.block import WrappedBloomBlock
|
||||||
|
from petals.models.bloom.config import DistributedBloomConfig
|
||||||
|
from petals.models.bloom.model import (
|
||||||
|
DistributedBloomForCausalLM,
|
||||||
|
DistributedBloomForSequenceClassification,
|
||||||
|
DistributedBloomModel,
|
||||||
|
)
|
||||||
|
from petals.utils.auto_config import register_model_classes
|
||||||
|
|
||||||
|
register_model_classes(
|
||||||
|
config=DistributedBloomConfig,
|
||||||
|
model=DistributedBloomModel,
|
||||||
|
model_for_causal_lm=DistributedBloomForCausalLM,
|
||||||
|
model_for_sequence_classification=DistributedBloomForSequenceClassification,
|
||||||
|
)
|
32
src/petals/models/bloom/block.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
Bloom intermediate layer
|
||||||
|
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
||||||
|
See commit history for authorship.
|
||||||
|
"""
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedBloomBlock(BloomBlock):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
alibi: Optional[torch.Tensor] = None,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
assert attention_mask is None, "Non-causal attention masks are not supported yet"
|
||||||
|
batch_size, seq_length = hidden_states.shape[:2]
|
||||||
|
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
|
||||||
|
seq_length_with_past = seq_length + past_length
|
||||||
|
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||||
|
if alibi is None:
|
||||||
|
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
|
||||||
|
attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
|
||||||
|
return super().forward(
|
||||||
|
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
|
||||||
|
)
|
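To make the alibi handling above easier to follow, a small shape check using the same transformers helper; the batch size, head count, and sequence length are arbitrary, and the exact return shape depends on the transformers version this diff targets.

import torch
from transformers.models.bloom.modeling_bloom import build_alibi_tensor

batch_size, num_heads, seq_len = 2, 4, 5
attention_mask = torch.ones(batch_size, seq_len)

alibi = build_alibi_tensor(attention_mask, num_heads=num_heads, dtype=torch.float32)
print(alibi.shape)  # expected: torch.Size([8, 1, 5]) == (batch_size * num_heads, 1, seq_len)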
34
src/petals/models/bloom/config.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers.models.bloom import BloomConfig
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||||
|
|
||||||
|
from petals.client.lm_head import LMHeadConfig
|
||||||
|
from petals.client.ptune import PTuneConfig
|
||||||
|
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
||||||
|
from petals.models.bloom.block import WrappedBloomBlock
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
|
||||||
|
block_class = WrappedBloomBlock
|
||||||
|
attn_class = BloomAttention
|
||||||
|
block_prefix = "h"
|
||||||
|
|
||||||
|
num_key_value_groups = 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license")
|
||||||
|
|
||||||
|
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
||||||
|
if loading_from_repo and dht_prefix is None:
|
||||||
|
# We need "-petals" for backward compatibility with Petals < 1.2.0
|
||||||
|
dht_prefix = str(model_name_or_path) + "-petals"
|
||||||
|
logger.info(f"Using DHT prefix: {dht_prefix}")
|
||||||
|
return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
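A hedged usage sketch of the dht_prefix fallback implemented above; it assumes the petals package from this diff is installed and the Hugging Face Hub is reachable.

from petals.models.bloom import DistributedBloomConfig

# Loading from a hub repo with no explicit dht_prefix falls back to "<repo>-petals"
# (the "-petals" suffix keeps compatibility with Petals < 1.2.0, as noted above).
config = DistributedBloomConfig.from_pretrained("bigscience/bloom-560m")
print(config.dht_prefix)  # expected: "bigscience/bloom-560m-petals"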
126
src/petals/models/bloom/model.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
|
||||||
|
|
||||||
|
from petals.client.from_pretrained import FromPretrainedMixin
|
||||||
|
from petals.client.lm_head import LMHead
|
||||||
|
from petals.client.ptune import PTuneMixin
|
||||||
|
from petals.client.remote_generation import RemoteGenerationMixin
|
||||||
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
|
from petals.models.bloom.config import DistributedBloomConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
|
||||||
|
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"^h\."]
|
||||||
|
|
||||||
|
config_class = DistributedBloomConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||||
|
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||||
|
super().__init__(config)
|
||||||
|
assert len(self.h) == 0
|
||||||
|
config.num_hidden_layers = n_layer
|
||||||
|
|
||||||
|
self.h = RemoteSequential(config, dht=dht)
|
||||||
|
|
||||||
|
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
||||||
|
self.init_prompts(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
|
||||||
|
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if not (v is None or v is False):
|
||||||
|
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||||
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||||
|
|
||||||
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
|
||||||
|
else:
|
||||||
|
hidden_states = self.h(hidden_states)
|
||||||
|
|
||||||
|
# Remove prefix
|
||||||
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
||||||
|
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||||
|
|
||||||
|
# Add last hidden state
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=None,
|
||||||
|
hidden_states=None,
|
||||||
|
attentions=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
|
||||||
|
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
|
||||||
|
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
|
||||||
|
|
||||||
|
config_class = DistributedBloomConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedBloomConfig):
|
||||||
|
BloomPreTrainedModel.__init__(self, config)
|
||||||
|
self.transformer = DistributedBloomModel(config)
|
||||||
|
self.lm_head = LMHead(config)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
|
||||||
|
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
|
||||||
|
|
||||||
|
config_class = DistributedBloomConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedBloomConfig):
|
||||||
|
BloomPreTrainedModel.__init__(self, config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
|
self.transformer = DistributedBloomModel(config)
|
||||||
|
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
15
src/petals/models/llama/__init__.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from petals.models.llama.block import WrappedLlamaBlock
|
||||||
|
from petals.models.llama.config import DistributedLlamaConfig
|
||||||
|
from petals.models.llama.model import (
|
||||||
|
DistributedLlamaForCausalLM,
|
||||||
|
DistributedLlamaForSequenceClassification,
|
||||||
|
DistributedLlamaModel,
|
||||||
|
)
|
||||||
|
from petals.utils.auto_config import register_model_classes
|
||||||
|
|
||||||
|
register_model_classes(
|
||||||
|
config=DistributedLlamaConfig,
|
||||||
|
model=DistributedLlamaModel,
|
||||||
|
model_for_causal_lm=DistributedLlamaForCausalLM,
|
||||||
|
model_for_sequence_classification=DistributedLlamaForSequenceClassification,
|
||||||
|
)
|
91
src/petals/models/llama/block.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
"""
|
||||||
|
LLaMA intermediate layer
|
||||||
|
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
|
See commit history for authorship.
|
||||||
|
"""
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedLlamaBlock(LlamaDecoderLayer):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
batch_size, seq_length, _ = hidden_states.shape
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
past_key_value = layer_past
|
||||||
|
if past_key_value is not None:
|
||||||
|
past_key_values_length = past_key_value[0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = hidden_states.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||||
|
)
|
||||||
|
attention_mask = LlamaModel._prepare_decoder_attention_mask(
|
||||||
|
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = super().forward(
|
||||||
|
hidden_states,
|
||||||
|
*args,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
use_cache=use_cache,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
present_key_value = outputs[-1]
|
||||||
|
present_key_value = self._reorder_cache_from_llama_to_bloom(
|
||||||
|
present_key_value, batch_size, seq_length_with_past
|
||||||
|
)
|
||||||
|
outputs = outputs[:-1] + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _reorder_cache_from_bloom_to_llama(
|
||||||
|
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
key_states, value_states = key_value
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
key_states = key_states.view(
|
||||||
|
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||||
|
)
|
||||||
|
value_states = value_states.view(*key_states.shape)
|
||||||
|
return (key_states, value_states)
|
||||||
|
|
||||||
|
def _reorder_cache_from_llama_to_bloom(
|
||||||
|
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
key_states, value_states = key_value
|
||||||
|
value_states = value_states.view(
|
||||||
|
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||||
|
)
|
||||||
|
key_states = key_states.view(*value_states.shape)
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
return (key_states, value_states)
|
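A minimal sketch (not part of the diff) of the layout conversion that _reorder_cache_from_bloom_to_llama and _reorder_cache_from_llama_to_bloom perform; it uses plain tensors and .reshape instead of .view for clarity, with made-up sizes:

import torch

batch_size, num_kv_heads, head_dim, seq_len = 2, 4, 8, 5

# BLOOM-style server cache: keys are [batch * heads, head_dim, length],
# values are [batch * heads, length, head_dim]
bloom_keys = torch.randn(batch_size * num_kv_heads, head_dim, seq_len)
bloom_values = torch.randn(batch_size * num_kv_heads, seq_len, head_dim)

# BLOOM -> LLaMA: both tensors become [batch, heads, length, head_dim]
llama_keys = bloom_keys.permute(0, 2, 1).reshape(batch_size, num_kv_heads, seq_len, head_dim)
llama_values = bloom_values.reshape(batch_size, num_kv_heads, seq_len, head_dim)

# LLaMA -> BLOOM: invert the two transformations
restored_values = llama_values.reshape(batch_size * num_kv_heads, seq_len, head_dim)
restored_keys = llama_keys.reshape(batch_size * num_kv_heads, seq_len, head_dim).permute(0, 2, 1)

assert torch.equal(restored_values, bloom_values)
assert torch.equal(restored_keys, bloom_keys)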
45
src/petals/models/llama/config.py
Normal file
@@ -0,0 +1,45 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.llama.block import WrappedLlamaBlock

logger = get_logger(__name__)


class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedLlamaBlock
    attn_class = LlamaAttention
    block_prefix = "model.layers"

    @property
    def num_key_value_groups(self):
        return self.num_attention_heads // self.num_key_value_heads

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        logger.info(
            "Make sure you follow the LLaMA's terms of use: "
            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
        )

        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
            if not dht_prefix.endswith("-hf"):
                dht_prefix += "-hf"
            logger.info(f"Using DHT prefix: {dht_prefix}")

        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
        return result
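To make the prefix derivation above concrete, here is a tiny standalone sketch of the same string logic (the helper name and repo names are illustrative, not part of the diff):

def derive_dht_prefix(model_name_or_path: str) -> str:
    # Use only the repo name so that blocks hosted under different accounts merge into one swarm
    prefix = model_name_or_path.split("/")[-1]
    if not prefix.endswith("-hf"):
        prefix += "-hf"
    return prefix

assert derive_dht_prefix("meta-llama/Llama-2-70b-hf") == "Llama-2-70b-hf"
assert derive_dht_prefix("huggyllama/llama-65b") == "llama-65b-hf"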
151
src/petals/models/llama/model.py
Normal file
@@ -0,0 +1,151 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.models.llama.config import DistributedLlamaConfig

logger = get_logger(__name__)


class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
    """LlamaModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.layers) == 0
        config.num_hidden_layers = n_layer

        self.layers = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = inputs_embeds
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.layers(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers

    @property
    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return self.norm


class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config: DistributedLlamaConfig):
        LlamaPreTrainedModel.__init__(self, config)
        self.model = DistributedLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model


class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedLlamaConfig

    def __init__(self, config):
        LlamaPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.model = DistributedLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model
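A hedged client-side usage sketch (not part of the diff) of the classes defined above; the repo name is illustrative and a reachable Petals swarm plus the usual transformers dependencies are assumed:

import torch
from transformers import AutoTokenizer
from petals.models.llama import DistributedLlamaForCausalLM

model_name = "meta-llama/Llama-2-70b-hf"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

# RemoteGenerationMixin routes the transformer layers through the swarm; embeddings and the LM head run locally.
inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))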
@@ -1,44 +1,77 @@
"""Code for serving bloom blocks via hivemind-server"""
from __future__ import annotations
from typing import Any, Dict, Sequence, Tuple
from collections import Counter
from itertools import chain
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
from hivemind import BatchTensorDescriptor
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from tensor_parallel import TensorParallel
from tensor_parallel.tensor_parallel import PerDeviceTensors
from transformers import PretrainedConfig

from petals.bloom.block import WrappedBloomBlock
from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy

logger = get_logger(__file__)
logger = get_logger(__name__)


class TransformerBackend(ModuleBackend):
"""A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
"""A wrapper for a transformer block that can process requests for forward, backward and inference"""

_peft_module = None

def __init__(
self,
*args,
config: PretrainedConfig,
memory_cache: MemoryCache,
backend_dtype: torch.dtype,
max_chunk_size_bytes: int,
**kwargs,
):
import petals.utils.peft as _peft_module

self._peft_module = _peft_module

def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
super().__init__(*args, **kwargs)
assert isinstance(self.module, WrappedBloomBlock)
assert isinstance(self.module, TensorParallel)
self.config = config
self.memory_cache = memory_cache
self.max_chunk_size_bytes = max_chunk_size_bytes

for name, param in self.module.named_parameters():
assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
for name, buf in self.module.named_buffers():
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"

max_batch_size = self.forward_pool.max_batch_size
device = self.module.devices[self.module.output_device_index]
self.inference_pool = PrioritizedTaskPool(
self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
)
)  # note: inference_pools may be merged later, see merge_inference_pools_inplace
self.forward_pool = PrioritizedTaskPool(
self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
)
self.backward_pool = PrioritizedTaskPool(
self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
)

assert backend_dtype is not None
self.dtype = backend_dtype
self.dtype_bytes = torch.finfo(self.dtype).bits // 8
self.shard_num_heads = []
for shard in self.module.module_shards:
for submodule in shard.modules():
if isinstance(submodule, config.attn_class):
self.shard_num_heads.append(submodule.num_heads)
assert len(self.shard_num_heads) == len(self.module.devices)
assert sum(self.shard_num_heads) == config.num_attention_heads

self.inference_schema = (
(
*self.args_schema,
@@ -48,43 +81,102 @@ class TransformerBackend(ModuleBackend):
self.kwargs_schema,
)

self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8

def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
"""Create tensor descriptors for attention cache tensors used during inference_step"""
head_dim = self.config.hidden_size // self.config.num_attention_heads
cache_tensors = []
for device, num_heads in zip(self.module.devices, self.shard_num_heads):
num_heads //= self.config.num_key_value_groups
keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
cache_tensors.extend((keys, values))
return cache_tensors

def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
*inputs, active_adapter = inputs
with self._peft_module.using_adapter(active_adapter):
return super().forward(*inputs)

def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
*inputs, active_adapter = inputs
with self._peft_module.using_adapter(active_adapter):
return super().backward(*inputs)

@torch.inference_mode()
def inference_step(
self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor
self,
hidden_states: torch.Tensor,
hypo_ids: torch.LongTensor,
inference_info: InferenceMetadata,
) -> Tuple[torch.Tensor, ...]:
num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with torch.inference_mode():
seq_len = hidden_states.shape[1]
assert (
hidden_states.ndim == 3
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
cache_handle, rel_index, prefix_length = map(int, cache_metadata[0])

with self.memory_cache.use_cache(cache_handle) as cache:
with self.memory_cache.use_cache(
batch_size = cache.shape[2]
*inference_info.cache_handles
max_length = cache.shape[-1] // (head_dim * num_heads)
) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4
self._reorder_cache_inplace(cache_tensors, hypo_ids)
if not is_dummy(hypo_ids):
assert hypo_ids.shape[0] == batch_size
cache[rel_index, :, :] = cache[rel_index, :, hypo_ids]  # in-place reorder cache by hypo ids
key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)

key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
# We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
# reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes`
logger.debug(
# is at least 4-6x less than `autograd_memory`.
f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}"
max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
for offset in range(0, seq_len, max_chunk_length):
hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
output_hidden_states_chunk, new_kvs = self.module.forward(
hidden_states_chunk, layer_past=layer_past, use_cache=True
)
hidden_states, (new_key, new_value) = self.module.forward(
if seq_len > max_chunk_length:
hidden_states, layer_past=(key_past, value_past), use_cache=True
output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
)
else:
new_length = new_key.shape[-1]
output_hidden_states = output_hidden_states_chunk  # saves one memcopy
assert new_length > prefix_length
layer_past = new_kvs
assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0]
assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length
self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
new_key = new_key.view(batch_size, num_heads, head_dim, -1)
return (output_hidden_states,)
new_value = new_value.view(batch_size, num_heads, -1, head_dim)
key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:
value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
# We assume that attention logit matrices are the main thing that consumes memory, given that
return (hidden_states,)
# the model uses multi-query attention
batch_size, seq_length, hidden_size = hidden_states.shape
worst_case_length = inference_info.prefix_length + seq_length
attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length
return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)

def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
"""If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
if not is_dummy(hypo_ids):
for cache_tensor in cache_tensors:
cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids

def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
"""Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
for i in range(len(key_cache)):
key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]
# shape: [batch * num_kv_heads, head_dim, kv_length]
value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]
# shape: [batch * num_kv_heads, kv_length, head_dim]
layer_past = tuple(chain(*zip(key_cache, value_cache)))
return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past

def _update_cache_inplace(
self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
):
"""Writes new key/value tensors back into cache, works in-place"""
_batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape
for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
new_key = new_key.view(*cache_key.shape[:3], new_length)
cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]

def get_pools(self) -> Sequence[PrioritizedTaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool
@@ -102,3 +194,40 @@ class TransformerBackend(ModuleBackend):
dummy = torch.tensor([])
for p in self.module.parameters():
p.data = dummy


def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
"""Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
first_pool = next(iter(backends.values())).inference_pool
merged_pool = PrioritizedTaskPool(
_MergedInferenceStep(backends),
max_batch_size=first_pool.max_batch_size,
device=first_pool.device,
name=f"merged_inference",
)
for backend in backends.values():
assert not backend.inference_pool.is_alive()
backend.inference_pool = merged_pool


class _MergedInferenceStep:
def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
self.backends = backends

@torch.inference_mode()
def __call__(
self,
hidden_states: torch.Tensor,
hypo_ids: torch.LongTensor,
inference_infos: Sequence[InferenceMetadata],
*optional_prompts: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, ...]:
assert len(inference_infos) == len(
optional_prompts
), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
if optional_prompt is not None:
hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
(hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
return (hidden_states,)
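To make the chunking rule in _estimate_max_chunk_length concrete, here is a small worked example of the same arithmetic with invented numbers (not measurements from this PR):

# Assumed setup: one shard with 64 attention heads, batch of 1, float16 values (2 bytes),
# a 2048-token prefix plus 1024 incoming tokens, and a 256 MiB per-chunk budget.
max_shard_num_heads = 64
batch_size = 1
dtype_bytes = 2                      # torch.finfo(torch.float16).bits // 8
worst_case_length = 2048 + 1024      # prefix_length + incoming seq_length
max_chunk_size_bytes = 256 * 1024**2

attn_bytes_per_token = max_shard_num_heads * batch_size * dtype_bytes * worst_case_length
max_chunk_length = max(1, max_chunk_size_bytes // attn_bytes_per_token)
print(attn_bytes_per_token, max_chunk_length)  # 393216 bytes per token -> chunks of 682 tokens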
211
src/petals/server/block_functions.py
Normal file
@@ -0,0 +1,211 @@
"""
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
"""
from __future__ import annotations

from typing import AsyncIterator, Optional, Sequence, Tuple, Union

import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.nested import nested_flatten

from petals.data_structures import InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy

# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1


async def run_rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, prompts = flat_tensors
    dtype = requested_backends[0].dtype
    # check parse input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, : prompt.shape[1]] += prompt

        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
            active_adapter,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

    return hidden_states


async def run_rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    inputs, grad_outputs, prompts = flat_tensors
    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)

        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []
    # Run a chain of requested backends
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape


async def iterate_rpc_inference(
    requested_uids: Sequence[ExpertUID],
    requested_backends: Sequence[TransformerBackend],
    active_adapter: Optional[str],
    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
    cache_handles: Sequence[Sequence[Handle]],
    *,
    max_length: int,
    prioritizer: TaskPrioritizerBase,
    points: int,
    quant_type: QuantType,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
    assert len(cache_handles) == len(requested_backends)

    prefix_length = 0
    point_per_piece = points / max_length if max_length > 0 else 0.0

    async for request, step_metadata in input_iterator:
        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
        batch_size, length_increment, _ = hidden_states.shape

        # Cast inputs to backend dtype
        hidden_states = hidden_states.to(requested_backends[0].dtype)
        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"

        # parse deep prompts (optional argument)
        has_prompts = prompts is not None and not is_dummy(prompts)
        if not has_prompts:
            prompts = [None] * len(requested_backends)
        else:
            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]

        if not (len(requested_backends) == len(prompts)):
            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")

        if prefix_length + length_increment > max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                f" exceeds pre-allocated maximum {max_length}"
            )

        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
        can_merge_pools = batch_size * length_increment <= merge_max_tokens
        priority = prioritizer.prioritize(
            hidden_states,
            hypo_ids,
            points=point_per_piece,
            requested_uids=requested_uids,
            type="short_inference" if can_merge_pools else "inference",
        )

        # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
        # when user wants to pre-allocate cache or check that server *can* allocate that cache.
        if hidden_states.numel() > 0:
            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
            if can_merge_pools:
                inference_infos = tuple(
                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                    for uid, handles in zip(requested_uids, cache_handles)
                )
                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
                )
            else:
                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                    (hidden_states,) = await backend.inference_pool.submit_task(
                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
                    )

        # serialize and send last layer outputs
        output_tensors = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]
        can_push = not has_prompts
        yield output_tensors, can_push

        # prepare for next step
        prefix_length += length_increment
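The "deep prompt" addition used in run_rpc_forward and _MergedInferenceStep is easy to miss; a minimal standalone sketch (invented shapes, not part of the diff) of what that single line does:

import torch

# A per-block prompt of shape [batch, prompt_len, hid_size] is added to the first prompt_len
# positions of the hidden states; the remaining positions are left untouched.
batch_size, seq_len, hid_size, prompt_len = 2, 10, 16, 3
hidden_states = torch.zeros(batch_size, seq_len, hid_size)
prompt = torch.ones(batch_size, prompt_len, hid_size)

hidden_states[:, : prompt.shape[1]] += prompt

assert hidden_states[:, :prompt_len].sum() == batch_size * prompt_len * hid_size
assert hidden_states[:, prompt_len:].sum() == 0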
@@ -8,7 +8,7 @@ from petals.data_structures import RemoteModuleInfo, ServerState

__all__ = ["choose_best_blocks", "should_choose_other_blocks"]

logger = get_logger(__file__)
logger = get_logger(__name__)


@dataclass
@@ -16,6 +16,7 @@ class Span:
start: int
end: int
throughput: float
state: ServerState

@property
def length(self):
@@ -43,7 +44,7 @@ def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[
spans[peer_id].start = min(spans[peer_id].start, block)
spans[peer_id].end = max(spans[peer_id].start, block + 1)
else:
spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)

throughputs[block] += server.throughput

@@ -79,6 +80,9 @@ def should_choose_other_blocks(
# Also, subtracting local_span.throughput * (1 + eps) makes _choose_best_start() prefer
# the previous server position in case of other things being almost equal.

if initial_throughput > eps and throughputs.min() <= 0:
return False  # Switching blocks would make the swarm disjoint

new_start = _choose_best_start(throughputs, local_span.length)
if local_span.start == new_start:
return False  # This server is on its best place already
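A rough illustration (invented numbers, loosely mirroring should_choose_other_blocks) of the disjoint-swarm guard added above: when this server already contributes throughput and some block would be left with zero total throughput, switching positions is refused so the swarm stays connected.

import numpy as np

eps = 1e-3
initial_throughput = 100.0                 # throughput this server provides from its current blocks
throughputs = np.array([50.0, 0.0, 80.0])  # per-block throughput without this server's contribution

would_disconnect = initial_throughput > eps and throughputs.min() <= 0
assert would_disconnect  # -> should_choose_other_blocks() returns False in this situation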
@@ -2,47 +2,50 @@ from typing import Optional, Union

import torch
from accelerate import init_empty_weights
from transformers import BloomConfig
from transformers import PretrainedConfig

from petals.bloom.block import WrappedBloomBlock
from petals.utils.convert_block import QuantType


def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
"""If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
if dtype not in ("auto", None):
if dtype == "auto" or dtype is None:
return dtype
dtype = config.torch_dtype
if config.torch_dtype not in ("auto", None, torch.float32):
if dtype == "auto" or dtype is None:
# If config specifies float32, we override it to the default dtype below
dtype = torch.float32
return config.torch_dtype
return dtype
return torch.bfloat16


def get_block_size(
config: BloomConfig,
config: PretrainedConfig,
location: str,
*,
dtype: Optional[Union[str, torch.dtype]] = None,
load_in_8bit: Optional[bool] = None,
quant_type: QuantType = QuantType.NONE,
eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
) -> int:
if location == "memory":
assert (
dtype is not None and load_in_8bit is not None
dtype is not None and quant_type is not None
), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'

with init_empty_weights():
with init_empty_weights(include_buffers=True):
block = WrappedBloomBlock(config)
block = config.block_class(config)
n_params = sum(param.numel() for param in block.parameters())

if location == "memory" and load_in_8bit:
# Note: We may need a larger eps here for models of size < 1B
return n_params * (1 + eps)

if location == "memory":
dtype = resolve_block_dtype(config, dtype)
if quant_type == QuantType.NONE:
dtype = resolve_block_dtype(config, dtype)
bytes_per_value = torch.finfo(dtype).bits // 8
elif quant_type == QuantType.INT8:
bytes_per_value = 1
elif quant_type == QuantType.NF4:
bytes_per_value = 4.25 / 8  # Bitness of NF4 with this config (measured empirically)
else:
raise ValueError(f"Unsupported quant_type={quant_type}")
elif location == "disk":
dtype = resolve_block_dtype(config, "auto")
else:
bytes_per_value = torch.finfo(dtype).bits // 8
raise ValueError('get_block_size() expects location to be "memory" or "disk"')

return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps))
return round(n_params * bytes_per_value * (1 + eps))
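A worked example (with an invented parameter count) of the size estimate above, comparing the three bytes-per-value settings that the new code distinguishes:

# Assumed block with 100M parameters and the default eps = 0.01 metainfo overhead.
n_params = 100_000_000
eps = 0.01
bytes_per_value_fp16 = 2          # torch.finfo(torch.float16).bits // 8
bytes_per_value_int8 = 1
bytes_per_value_nf4 = 4.25 / 8    # empirical NF4 bitness, as noted in the code above

for label, bpv in [("fp16", bytes_per_value_fp16), ("int8", bytes_per_value_int8), ("nf4", bytes_per_value_nf4)]:
    print(label, round(n_params * bpv * (1 + eps)))
# fp16 -> 202000000 bytes, int8 -> 101000000 bytes, nf4 -> 53656250 bytes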
177
src/petals/server/from_pretrained.py
Normal file
@@ -0,0 +1,177 @@
"""
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
If necessary, one can rewrite this to implement a different behavior, such as:
 - loading files from a local data source (e.g. S3)
 - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
 - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )

"""
import json
import time
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from huggingface_hub import get_hf_file_metadata, hf_hub_url
from transformers import PretrainedConfig
from transformers.utils import get_file_from_repo

from petals.constants import DTYPE_MAP
from petals.server.block_utils import resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.hf_auth import always_needs_auth

logger = get_logger(__name__)


def load_pretrained_block(
    model_name: str,
    block_index: int,
    *,
    config: Optional[PretrainedConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> nn.Module:
    if config is None:
        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR

    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
    torch_dtype = resolve_block_dtype(config, torch_dtype)

    with init_empty_weights():
        block = config.block_class(config)

    block_prefix = f"{config.block_prefix}.{block_index}."
    state_dict = _load_state_dict_from_repo(
        model_name,
        block_prefix,
        revision=revision,
        token=token,
        cache_dir=cache_dir,
        max_disk_space=max_disk_space,
    )

    # dummy load, check that keys match
    report = block.load_state_dict(state_dict, strict=True)
    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"

    for param_name, _ in block.named_parameters():
        assert param_name in state_dict, f"{param_name} not in state dict"
        param = state_dict[param_name]
        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            param = param.to(torch_dtype)
        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

    logger.info(f"Loaded {model_name} block {block_index}, {report}")
    return block


StateDict = Dict[str, torch.Tensor]


def _load_state_dict_from_repo(
    model_name: str,
    block_prefix: str,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
) -> StateDict:
    if always_needs_auth(model_name) and token is None:
        token = True

    index_file = get_file_from_repo(
        model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
    )
    if index_file is not None:  # Sharded model
        with open(index_file) as f:
            index = json.load(f)
        filenames = {
            filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
        }
        if not filenames:
            raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
    else:  # Non-sharded model
        filenames = {"pytorch_model.bin"}
    logger.debug(f"Loading {block_prefix}* from {filenames}")

    state_dict = {}
    for filename in filenames:
        shard_state_dict = _load_state_dict_from_file(
            model_name,
            filename,
            revision=revision,
            token=token,
            cache_dir=cache_dir,
            max_disk_space=max_disk_space,
        )
        shard_state_dict = {
            param_name[len(block_prefix) :]: param
            for param_name, param in shard_state_dict.items()
            if param_name.startswith(block_prefix)
        }  # Remove unused parameters from memory
        state_dict.update(shard_state_dict)
    return state_dict


def _load_state_dict_from_file(
    model_name: str,
    filename: str,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
) -> StateDict:
    # First, try to find the weights locally
    try:
        with allow_cache_reads(cache_dir):
            path = get_file_from_repo(
                model_name,
                filename,
                revision=revision,
                use_auth_token=token,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            if path is not None:
                return torch.load(path, map_location="cpu")
    except Exception:
        logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)

    # If not found, ensure that we have enough disk space to download them (maybe remove something)
    while True:
        try:
            with allow_cache_writes(cache_dir):
                url = hf_hub_url(model_name, filename, revision=revision)
                file_size = get_hf_file_metadata(url, token=token).size
                if file_size is not None:
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")

                path = get_file_from_repo(
                    model_name,
                    filename,
                    revision=revision,
                    use_auth_token=token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
                if path is None:
                    raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
                return torch.load(path, map_location="cpu")
        except Exception as e:
            logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
            time.sleep(delay)
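A minimal sketch (illustrative index contents, not part of the diff) of how _load_state_dict_from_repo narrows a sharded checkpoint down to a single block: the index maps parameter names to shard files, and only the shards that contain "<block_prefix>*" entries are fetched.

index = {
    "weight_map": {
        "model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
        "model.layers.3.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
        "model.layers.4.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
    }
}  # hypothetical pytorch_model.bin.index.json contents

block_prefix = "model.layers.3."
filenames = {fn for name, fn in index["weight_map"].items() if name.startswith(block_prefix)}
assert filenames == {"pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"}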
@@ -1,6 +1,12 @@
from __future__ import annotations

import asyncio
import contextlib
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import multiprocessing as mp
import sys
from enum import Enum
from itertools import chain
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple

import torch
from async_timeout import timeout
@@ -8,10 +14,11 @@ from hivemind import (
DHT,
MSGPackSerializer,
P2PContext,
TensorDescriptor,
PeerID,
deserialize_tensor_stream,
deserialize_torch_tensor,
nested_flatten,
nested_pack,
serialize_torch_tensor,
)
from hivemind.moe.server.connection_handler import ConnectionHandler
@@ -21,13 +28,29 @@ from hivemind.utils.asyncio import amap_in_executor, anext
from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming

from petals.data_structures import CHAIN_DELIMITER, ModuleUID
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
from petals.server.memory_cache import Handle
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.convert_block import QuantType

logger = get_logger(__file__)
logger = get_logger(__name__)


# Fix pickling protobufs, see https://stackoverflow.com/a/74873028
sys.modules["runtime_pb2"] = runtime_pb2


CACHE_TOKENS_AVAILABLE = "cache_tokens_available"


class Event(Enum):
NEW_SESSION = 0
END_SESSION = 1
PUSH = 2
SHUTDOWN = 3


class TransformerConnectionHandler(ConnectionHandler):
@@ -40,23 +63,45 @@ class TransformerConnectionHandler(ConnectionHandler):
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
adapters: Optional[Sequence[str]],
dht_prefix: str,
handler_event_queues: Sequence[mp.Queue],
handler_index: int,
inference_max_length: int,
request_timeout: float,
session_timeout: float,
step_timeout: float,
task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
quant_type: QuantType,
):
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
assert isinstance(module_backend, TransformerBackend)
self.dht_prefix = dht_prefix
self.adapters = adapters
self._handler_event_queues = handler_event_queues
self._handler_index = handler_index
self._own_event_queue = handler_event_queues[handler_index]
self._listener_task: Optional[asyncio.Task] = None
self._session_queues: Dict[str, asyncio.Queue] = {}
self._session_handlers: Dict[str, int] = {}

self.inference_max_length = inference_max_length
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self._prioritizer = task_prioritizer
self.quant_type = quant_type

async def add_p2p_handlers(self, *args, **kwargs) -> None:
if self._listener_task is None:
# Start listening to our own event queue before we accept any requests
self._listener_task = asyncio.create_task(self._listen_to_event_queue())
await super().add_p2p_handlers(*args, **kwargs)

def shutdown(self):
if self.is_alive():
self._outer_pipe.send("_shutdown")
self._own_event_queue.put((Event.SHUTDOWN, None, None))
self.join(self.shutdown_timeout)
if self.is_alive():
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
@@ -89,9 +134,8 @@ class TransformerConnectionHandler(ConnectionHandler):
self,
requests: AsyncIterator[runtime_pb2.ExpertRequest],
context: P2PContext,
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
"""Compute a single step of inference using attention cache; update attention cache accordingly."""

async with timeout(self.session_timeout):
try:
request = await asyncio.wait_for(anext(requests), self.step_timeout)
@@ -106,7 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|||||||
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
||||||
max_length = metadata.get("max_length")
|
max_length = metadata.get("max_length")
|
||||||
points = metadata.get("points", 0)
|
points = metadata.get("points", 0)
|
||||||
|
session_id = metadata.get("session_id")
|
||||||
if not requested_uids:
|
if not requested_uids:
|
||||||
raise ValueError("User must specify at least one block for inference, but got none")
|
raise ValueError("User must specify at least one block for inference, but got none")
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@ -120,92 +164,187 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|||||||
f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
|
f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
|
||||||
)
|
)
|
||||||
|
|
||||||
point_per_piece = points / max_length if max_length > 0 else 0.0
|
|
||||||
batch_size = request.tensors[0].size[0] if request.tensors else 1
|
batch_size = request.tensors[0].size[0] if request.tensors else 1
|
||||||
|
|
||||||
cache_metadata = torch.tensor(
|
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
|
||||||
[[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
|
background_tasks = set()
|
||||||
) # [cache_handle, rel_index, prefix_length]
|
async for output_tensors, can_push in iterate_rpc_inference(
|
||||||
prefix_length = 0
|
requested_uids=requested_uids,
|
||||||
|
requested_backends=requested_backends,
|
||||||
|
active_adapter=self._get_active_adapter(metadata),
|
||||||
|
input_iterator=self._iterate_inference_steps(
|
||||||
|
request, requests, session_id, requested_uids, context
|
||||||
|
),
|
||||||
|
cache_handles=cache_handles,
|
||||||
|
max_length=max_length,
|
||||||
|
prioritizer=self._prioritizer,
|
||||||
|
points=points,
|
||||||
|
quant_type=self.quant_type,
|
||||||
|
):
|
||||||
|
if can_push:
|
||||||
|
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
|
||||||
|
background_tasks.add(task) # Keep reference until it is done to save it from GC
|
||||||
|
task.add_done_callback(background_tasks.discard)
|
||||||
|
yield runtime_pb2.ExpertResponse(tensors=output_tensors)
|
||||||
|
|
||||||
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle:
|
|
||||||
while request.tensors: # iterate while user is willing to supply tensors
|
|
||||||
hidden_states, prompts, hypo_ids = [
|
|
||||||
deserialize_torch_tensor(tensor) for tensor in request.tensors
|
|
||||||
]
|
|
||||||
|
|
||||||
# Cast inputs to backend dtype
|
|
||||||
hidden_states = hidden_states.to(requested_backends[0].dtype)
|
|
||||||
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
|
|
||||||
|
|
||||||
# parse deep prompts (optional argument)
|
|
||||||
if prompts is None or is_dummy(prompts) or is_dummy(prompts):
|
|
||||||
prompts = [DUMMY] * len(requested_backends)
|
|
||||||
else:
|
|
||||||
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
|
|
||||||
|
|
||||||
if not (len(requested_backends) == len(prompts)):
|
|
||||||
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
|
|
||||||
|
|
||||||
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
|
|
||||||
if prefix_length + length_increment > max_length:
|
|
||||||
raise ValueError(
|
|
||||||
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
|
|
||||||
f" exceeds pre-allocated maximum {max_length}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# run request tensors through all requested modules, update caches
|
|
||||||
for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)):
|
|
||||||
if not is_dummy(prompt):
|
|
||||||
hidden_states[:, : prompt.shape[1]] += prompt
|
|
||||||
if hidden_states.numel() == 0:
|
|
||||||
continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
|
|
||||||
# when user wants to pre-allocate cache or check that server *can* allocate that cache
|
|
||||||
|
|
||||||
cache_metadata[:] = torch.tensor(
|
|
||||||
[cache_handle, rel_index, prefix_length], dtype=torch.int64
|
|
||||||
)
|
|
||||||
assert isinstance(
|
|
||||||
hidden_states, torch.Tensor
|
|
||||||
), f"hidden states must be tensor, got {type(hidden_states)}"
|
|
||||||
assert (
|
|
||||||
hidden_states.ndim == 3
|
|
||||||
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
|
||||||
assert isinstance(
|
|
||||||
backend.inference_pool, PrioritizedTaskPool
|
|
||||||
), "petals support only prioritized pools"
|
|
||||||
priority = self._prioritizer.prioritize(
|
|
||||||
cache_metadata,
|
|
||||||
hidden_states,
|
|
||||||
hypo_ids,
|
|
||||||
points=point_per_piece / len(requested_backends),
|
|
||||||
backend=backend,
|
|
||||||
type="inference",
|
|
||||||
)
|
|
||||||
(hidden_states,) = await backend.inference_pool.submit_task(
|
|
||||||
hidden_states, hypo_ids, cache_metadata, priority=priority
|
|
||||||
)
|
|
||||||
|
|
||||||
# serialize and send last layer outputs
|
|
||||||
yield runtime_pb2.ExpertResponse(
|
|
||||||
tensors=[
|
|
||||||
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
|
|
||||||
for result, proto in zip(
|
|
||||||
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# prepare for next step
|
|
||||||
prefix_length += hidden_states.shape[1]
|
|
||||||
try:
|
|
||||||
request = await asyncio.wait_for(anext(requests), self.step_timeout)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
|
|
||||||
return
|
|
||||||
finally:
|
finally:
|
||||||
self._log_request("rpc_inference.close", requested_uids, context)
|
self._log_request("rpc_inference.close", requested_uids, context)
|
||||||
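The new rpc_inference loop above pushes outputs to the next server in the background while the response is streamed back to the client, and it keeps each push task in a set so the garbage collector cannot drop it mid-flight. A small self-contained sketch of that keep-a-strong-reference pattern; push_outputs here is a stand-in for the real _push_outputs:

import asyncio

async def push_outputs(step: int) -> None:
    # Stand-in for the real network push; just simulates some latency.
    await asyncio.sleep(0.01)
    print(f"pushed outputs for step {step}")

async def inference_loop(n_steps: int = 3) -> None:
    background_tasks = set()
    for step in range(n_steps):
        task = asyncio.create_task(push_outputs(step))
        background_tasks.add(task)                        # keep a strong reference until it is done
        task.add_done_callback(background_tasks.discard)  # then allow it to be garbage-collected
        # ...the real server yields the ExpertResponse to the client here without awaiting the push...
    await asyncio.gather(*background_tasks)  # demo only: wait so the script exits cleanly

asyncio.run(inference_loop())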
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _managed_session(self, session_id: str):
|
||||||
|
assert session_id not in self._session_queues, f"session id {session_id} is not unique"
|
||||||
|
try:
|
||||||
|
self._session_queues[session_id] = asyncio.Queue()
|
||||||
|
self._session_handlers[session_id] = self._handler_index
|
||||||
|
for other_index, other_queue in enumerate(self._handler_event_queues):
|
||||||
|
if other_index != self._handler_index:
|
||||||
|
other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index))
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self._session_queues.pop(session_id).put_nowait(None) # put None so that the get task will not hang
|
||||||
|
del self._session_handlers[session_id]
|
||||||
|
for other_index, other_queue in enumerate(self._handler_event_queues):
|
||||||
|
if other_index != self._handler_index:
|
||||||
|
other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index))
|
||||||
|
|
||||||
|
def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest):
|
||||||
|
handler_index = self._session_handlers.get(session_id)
|
||||||
|
if handler_index is None:
|
||||||
|
logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}")
|
||||||
|
elif handler_index == self._handler_index:
|
||||||
|
self._session_queues[session_id].put_nowait(request)
|
||||||
|
else:
|
||||||
|
self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request))
|
||||||
|
|
||||||
|
async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]:
|
||||||
|
assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler"
|
||||||
|
return await self._session_queues[session_id].get()
|
||||||
|
|
||||||
|
async def _listen_to_event_queue(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get)
|
||||||
|
if event == Event.SHUTDOWN:
|
||||||
|
break
|
||||||
|
elif event == Event.NEW_SESSION:
|
||||||
|
self._session_handlers[session_id] = payload # index of the handler that owns that session
|
||||||
|
elif event == Event.END_SESSION:
|
||||||
|
self._session_handlers.pop(session_id, None)
|
||||||
|
elif event == Event.PUSH:
|
||||||
|
maybe_session_queue = self._session_queues.get(session_id)
|
||||||
|
if maybe_session_queue is not None:
|
||||||
|
maybe_session_queue.put_nowait(payload)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unexpected event: {event}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
|
||||||
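Because each connection handler runs in its own process, a pushed request may arrive at a handler that does not own the target session; the three helpers above (_managed_session, _put_into_session_queue, _listen_to_event_queue) forward it to the owner through per-handler event queues. A single-process toy sketch of that routing protocol, using the same (event, session_id, payload) tuples; queue names and session IDs are invented:

import queue
from enum import Enum

class Event(Enum):
    NEW_SESSION = 0
    END_SESSION = 1
    PUSH = 2
    SHUTDOWN = 3

# One event queue per handler, mirroring handler_event_queues in the constructor above.
handler_queues = [queue.Queue(), queue.Queue()]
session_handlers = {}  # session_id -> index of the owning handler (the real server keeps one copy per handler)

def announce_session(owner_index: int, session_id: str) -> None:
    # As in _managed_session: the owner tells every other handler who owns the new session.
    for index, q in enumerate(handler_queues):
        if index != owner_index:
            q.put((Event.NEW_SESSION, session_id, owner_index))

def route_push(receiver_index: int, session_id: str, payload: str) -> None:
    # As in _put_into_session_queue: deliver locally if we own the session, otherwise forward to the owner.
    owner = session_handlers.get(session_id)
    if owner is None:
        print(f"handler {receiver_index}: unknown session {session_id}, ignoring push")
    elif owner == receiver_index:
        print(f"handler {receiver_index}: delivering {payload!r} locally")
    else:
        handler_queues[owner].put((Event.PUSH, session_id, payload))
        print(f"handler {receiver_index}: forwarded {payload!r} to handler {owner}")

announce_session(owner_index=0, session_id="sess-42")
event, session_id, owner_index = handler_queues[1].get()  # handler 1 consumes the announcement
session_handlers[session_id] = owner_index                # ...as _listen_to_event_queue would
route_push(receiver_index=1, session_id="sess-42", payload="pushed hidden states")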
|
async def _iterate_inference_steps(
|
||||||
|
self,
|
||||||
|
first_request: runtime_pb2.ExpertRequest,
|
||||||
|
requests: AsyncIterator[runtime_pb2.ExpertRequest],
|
||||||
|
session_id: Optional[str],
|
||||||
|
requested_uids: Sequence[str],
|
||||||
|
context: P2PContext,
|
||||||
|
) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
|
||||||
|
processed_step_ids = set()
|
||||||
|
n_pushes = n_late_pushes = 0
|
||||||
|
request = first_request
|
||||||
|
anext_task = get_push_task = None
|
||||||
|
try:
|
||||||
|
with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext():
|
||||||
|
while request.tensors: # iterate while user is willing to supply tensors
|
||||||
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
||||||
|
step_id = metadata.get("step_id")
|
||||||
|
|
||||||
|
pushed = metadata.get("pushed")
|
||||||
|
if pushed:
|
||||||
|
n_pushes += 1
|
||||||
|
self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push")
|
||||||
|
|
||||||
|
if step_id is None or step_id not in processed_step_ids:
|
||||||
|
yield request, metadata
|
||||||
|
if step_id is not None:
|
||||||
|
processed_step_ids.add(step_id)
|
||||||
|
elif pushed:
|
||||||
|
n_late_pushes += 1
|
||||||
|
self._log_request(
|
||||||
|
"rpc_inference.push",
|
||||||
|
requested_uids,
|
||||||
|
context,
|
||||||
|
warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for the next request, coming either from the `requests` iterator or `push_queue`
|
||||||
|
if anext_task is None:
|
||||||
|
anext_task = asyncio.create_task(anext(requests))
|
||||||
|
if get_push_task is None:
|
||||||
|
if session_id is not None:
|
||||||
|
get_push_task = asyncio.create_task(self._get_from_session_queue(session_id))
|
||||||
|
else:
|
||||||
|
get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task
|
||||||
|
done, _ = await asyncio.wait(
|
||||||
|
[anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
|
||||||
|
if anext_task in done:
|
||||||
|
request = await anext_task
|
||||||
|
anext_task = None
|
||||||
|
elif get_push_task in done:
|
||||||
|
request = await get_push_task
|
||||||
|
get_push_task = None
|
||||||
|
else:
|
||||||
|
self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
|
||||||
|
anext_task.cancel()
|
||||||
|
get_push_task.cancel()
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
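_iterate_inference_steps above waits on two input sources at once, the client's own request stream and inputs pushed by the previous server, and deduplicates by step_id so the same step is never executed twice. A toy sketch of that select-and-deduplicate pattern; the asyncio queues stand in for the gRPC stream and the push queue:

import asyncio

async def iterate_steps(client_queue: asyncio.Queue, push_queue: asyncio.Queue, step_timeout: float = 0.5):
    """Yield inputs from whichever source delivers first, skipping duplicate step_ids."""
    processed_step_ids = set()
    while True:
        client_task = asyncio.create_task(client_queue.get())
        push_task = asyncio.create_task(push_queue.get())
        done, pending = await asyncio.wait(
            [client_task, push_task], timeout=step_timeout, return_when=asyncio.FIRST_COMPLETED
        )
        for task in pending:
            task.cancel()
        if not done:
            print("step timed out")
            return
        for task in done:
            step_id, payload = task.result()
            if step_id in processed_step_ids:
                continue  # a duplicate copy of a step that was already handled
            processed_step_ids.add(step_id)
            yield payload

async def main():
    client_q, push_q = asyncio.Queue(), asyncio.Queue()
    await push_q.put(("step-0", "pushed hidden states"))
    await client_q.put(("step-0", "client hidden states"))  # duplicate of the pushed step
    await client_q.put(("step-1", "client hidden states"))
    async for payload in iterate_steps(client_q, push_q):
        print("processing", payload)

asyncio.run(main())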
|
async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
||||||
|
"""Directly push activation tensors from one server to another"""
|
||||||
|
|
||||||
|
requested_uids = self._check_uids(request.uid)
|
||||||
|
metadata = MSGPackSerializer.loads(request.metadata)
|
||||||
|
session_id = metadata["session_id"]
|
||||||
|
self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
|
||||||
|
self._put_into_session_queue(session_id, request)
|
||||||
|
return runtime_pb2.ExpertResponse()
|
||||||
|
|
||||||
|
async def _push_outputs(
|
||||||
|
self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
next_servers = metadata.get("next_servers")
|
||||||
|
if not next_servers:
|
||||||
|
return
|
||||||
|
|
||||||
|
next_peer_id, next_session_id, next_start, next_end = next_servers[0]
|
||||||
|
next_peer_id = PeerID.from_base58(next_peer_id)
|
||||||
|
next_uid = CHAIN_DELIMITER.join(f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end))
|
||||||
|
|
||||||
|
# Sending hidden states serialized with output_schema to avoid double serialization
|
||||||
|
next_tensors = [serialized_outputs] + request.tensors[1:]
|
||||||
|
next_metadata = metadata.copy()
|
||||||
|
next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True)
|
||||||
|
|
||||||
|
stub = self.get_stub(self._p2p, next_peer_id)
|
||||||
|
await stub.rpc_push(
|
||||||
|
runtime_pb2.ExpertRequest(
|
||||||
|
uid=next_uid,
|
||||||
|
tensors=next_tensors,
|
||||||
|
metadata=MSGPackSerializer.dumps(next_metadata),
|
||||||
|
),
|
||||||
|
timeout=self.request_timeout,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
f"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
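_push_outputs above reads a client-provided route from metadata["next_servers"]: each hop is a (peer_id, session_id, start_block, end_block) tuple, and the UID string for the next request is rebuilt from the DHT prefix and the block range. A small sketch of that UID construction, assuming the usual Petals delimiters ("." between prefix and block index, a space between chained block UIDs); the peer IDs, prefix, and block numbers are invented:

UID_DELIMITER = "."    # assumed: joins dht_prefix and block index
CHAIN_DELIMITER = " "  # assumed: joins consecutive block UIDs into one chained request

def build_next_uid(dht_prefix: str, next_start: int, next_end: int) -> str:
    return CHAIN_DELIMITER.join(f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end))

next_servers = [("QmPeerA", "sess-1", 8, 12), ("QmPeerB", "sess-2", 12, 16)]  # hypothetical route
peer_id, session_id, start, end = next_servers[0]
print(build_next_uid("bigscience/bloom-petals", start, end))
# -> "bigscience/bloom-petals.8 bigscience/bloom-petals.9 bigscience/bloom-petals.10 bigscience/bloom-petals.11"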
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
||||||
async with timeout(self.request_timeout):
|
async with timeout(self.request_timeout):
|
||||||
# Parse request and prepare backends
|
# Parse request and prepare backends
|
||||||
@ -215,13 +354,18 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|||||||
|
|
||||||
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
||||||
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
||||||
|
active_adapter = self._get_active_adapter(metadata)
|
||||||
points = metadata.get("points", 0)
|
points = metadata.get("points", 0)
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
points, (float, int)
|
points, (float, int)
|
||||||
), f"rpc_forward should have number of points as number or None, got {points}"
|
), f"rpc_forward should have number of points as number or None, got {points}"
|
||||||
|
|
||||||
hidden_states = await _rpc_forward(
|
hidden_states = await run_rpc_forward(
|
||||||
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
|
*flat_inputs,
|
||||||
|
requested_backends=requested_backends,
|
||||||
|
prioritizer=self._prioritizer,
|
||||||
|
active_adapter=active_adapter,
|
||||||
|
points=points,
|
||||||
)
|
)
|
||||||
return runtime_pb2.ExpertResponse(
|
return runtime_pb2.ExpertResponse(
|
||||||
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
|
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
|
||||||
@ -237,13 +381,18 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|||||||
self._log_request("rpc_forward_stream", requested_uids, context)
|
self._log_request("rpc_forward_stream", requested_uids, context)
|
||||||
|
|
||||||
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
||||||
|
active_adapter = self._get_active_adapter(metadata)
|
||||||
points = metadata.get("points", 0)
|
points = metadata.get("points", 0)
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
points, (float, int)
|
points, (float, int)
|
||||||
), f"rpc_forward_stream should have number of points as number or None, got {points}"
|
), f"rpc_forward_stream should have number of points as number or None, got {points}"
|
||||||
|
|
||||||
hidden_states = await _rpc_forward(
|
hidden_states = await run_rpc_forward(
|
||||||
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
|
*flat_inputs,
|
||||||
|
requested_backends=requested_backends,
|
||||||
|
prioritizer=self._prioritizer,
|
||||||
|
active_adapter=active_adapter,
|
||||||
|
points=points,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Split the serialized_output for streaming and respond to client
|
# Split the serialized_output for streaming and respond to client
|
||||||
@ -283,13 +432,18 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|||||||
|
|
||||||
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
||||||
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
||||||
|
active_adapter = self._get_active_adapter(metadata)
|
||||||
points = metadata.get("points", 0)
|
points = metadata.get("points", 0)
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
points, (float, int)
|
points, (float, int)
|
||||||
), f"rpc_backward should have number of points as number or None, got {points}"
|
), f"rpc_backward should have number of points as number or None, got {points}"
|
||||||
|
|
||||||
grads = await _rpc_backward(
|
grads = await run_rpc_backward(
|
||||||
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
|
*flat_tensors,
|
||||||
|
requested_backends=requested_backends,
|
||||||
|
prioritizer=self._prioritizer,
|
||||||
|
active_adapter=active_adapter,
|
||||||
|
points=points,
|
||||||
)
|
)
|
||||||
|
|
||||||
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
|
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
|
||||||
@ -303,19 +457,30 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|||||||
self._log_request("rpc_backward_stream", requested_uids, context)
|
self._log_request("rpc_backward_stream", requested_uids, context)
|
||||||
|
|
||||||
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
||||||
|
active_adapter = self._get_active_adapter(metadata)
|
||||||
points = metadata.get("points", 0)
|
points = metadata.get("points", 0)
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
points, (float, int)
|
points, (float, int)
|
||||||
), f"rpc_backward_stream should have number of points as number or None, got {points}"
|
), f"rpc_backward_stream should have number of points as number or None, got {points}"
|
||||||
|
|
||||||
grads = await _rpc_backward(
|
grads = await run_rpc_backward(
|
||||||
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
|
*flat_tensors,
|
||||||
|
requested_backends=requested_backends,
|
||||||
|
prioritizer=self._prioritizer,
|
||||||
|
active_adapter=active_adapter,
|
||||||
|
points=points,
|
||||||
)
|
)
|
||||||
# Split the serialized_grad_inputs for streaming and respond
|
# Split the serialized_grad_inputs for streaming and respond
|
||||||
for tensor in self._serialize_grads(grads, requested_backends, metadata):
|
for tensor in self._serialize_grads(grads, requested_backends, metadata):
|
||||||
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
|
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
|
||||||
yield runtime_pb2.ExpertResponse(tensors=[part])
|
yield runtime_pb2.ExpertResponse(tensors=[part])
|
||||||
|
|
||||||
|
def _get_active_adapter(self, metadata: dict) -> str:
|
||||||
|
active_adapter = metadata.get("active_adapter", "")
|
||||||
|
if active_adapter and (active_adapter not in self.adapters):
|
||||||
|
raise KeyError(f"adapter {active_adapter} not found")
|
||||||
|
return active_adapter
|
||||||
|
|
||||||
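The new _get_active_adapter helper lets a client select one of the LoRA adapters the server was started with by setting metadata["active_adapter"]; any name outside self.adapters is rejected, and an empty string means the base model. A standalone sketch of that validation, with made-up adapter names:

from typing import Sequence

def get_active_adapter(metadata: dict, available_adapters: Sequence[str]) -> str:
    adapter = metadata.get("active_adapter", "")
    if adapter and adapter not in available_adapters:
        raise KeyError(f"adapter {adapter} not found")
    return adapter  # "" means: run the base model without any adapter

adapters = ("someuser/bloom-lora", "my-org/another-lora")  # hypothetical adapter names
print(get_active_adapter({"active_adapter": adapters[0]}, adapters))  # accepted
print(get_active_adapter({}, adapters))                               # falls back to ""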
def _serialize_grads(
|
def _serialize_grads(
|
||||||
self,
|
self,
|
||||||
grads: Sequence[torch.Tensor],
|
grads: Sequence[torch.Tensor],
|
||||||
@ -355,31 +520,23 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def _allocate_cache(
|
async def _allocate_cache(
|
||||||
self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
|
self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
|
||||||
) -> Sequence[int]:
|
) -> Sequence[Sequence[Handle]]:
|
||||||
"""Allocate memory cache for all transformer blocks, return cache handle"""
|
"""
|
||||||
|
Allocate memory cache for all transformer blocks and return the corresponding cache handles
|
||||||
n_blocks = len(backends)
|
:returns: a list of {len(backends)} elements, where the i-th element is a tuple of cache handles for the i-th backend
|
||||||
backend = backends[0]
|
"""
|
||||||
n_heads = backend.module.self_attention.num_heads
|
descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
|
||||||
head_dim = backend.module.self_attention.head_dim
|
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
|
||||||
descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype)
|
yield nested_pack(handles, descriptors)
|
||||||
alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
|
|
||||||
|
|
||||||
gib = 1024**3
|
|
||||||
cur_size = backend.memory_cache.current_size_bytes
|
|
||||||
max_size = backend.memory_cache.max_size_bytes
|
|
||||||
friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
|
|
||||||
logger.info(
|
|
||||||
f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), "
|
|
||||||
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
|
|
||||||
)
|
|
||||||
|
|
||||||
async with backend.memory_cache.allocate_cache(descr) as handle:
|
|
||||||
logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
|
|
||||||
yield handle
|
|
||||||
|
|
||||||
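The reworked _allocate_cache above asks every backend for its own tuple of cache descriptors, flattens them with itertools.chain for a single allocation call, and then uses nested_pack to restore the per-backend grouping, so cache_handles[i] holds the handles of the i-th block. A self-contained sketch of that flatten-and-repack step; plain strings stand in for TensorDescriptor objects, and the repacking is written out by hand to show what nested_pack does here:

from itertools import chain

# Pretend each backend needs two cache buffers (one for attention keys, one for values).
descriptors = [("k0", "v0"), ("k1", "v1"), ("k2", "v2")]  # structure: one tuple per backend

flat_descriptors = list(chain(*descriptors))  # what allocate_cache(*chain(*descriptors)) receives
handles = list(range(len(flat_descriptors)))  # pretend the cache returned one integer handle per buffer

# Repack the flat handles into the same per-backend structure (the role of nested_pack above).
handle_iter = iter(handles)
cache_handles = [tuple(next(handle_iter) for _ in backend_descrs) for backend_descrs in descriptors]
print(cache_handles)  # [(0, 1), (2, 3), (4, 5)] -> cache_handles[i] belongs to the i-th backend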
def _log_request(
|
def _log_request(
|
||||||
self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
|
self,
|
||||||
|
method: str,
|
||||||
|
uids: Optional[Sequence[ModuleUID]],
|
||||||
|
context: P2PContext,
|
||||||
|
*,
|
||||||
|
debug: Optional[str] = None,
|
||||||
|
warning: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if uids is not None:
|
if uids is not None:
|
||||||
friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
|
friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
|
||||||
@ -391,107 +548,28 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|||||||
friendly_remote_id = "..." + str(context.remote_id)[-6:]
|
friendly_remote_id = "..." + str(context.remote_id)[-6:]
|
||||||
|
|
||||||
message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
|
message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
|
||||||
if warning is None:
|
if warning is not None:
|
||||||
logger.info(message)
|
|
||||||
else:
|
|
||||||
logger.warning(f"{message}: {warning}")
|
logger.warning(f"{message}: {warning}")
|
||||||
|
elif debug is not None:
|
||||||
|
logger.debug(f"{message}: {debug}")
|
||||||
|
else:
|
||||||
|
logger.info(message)
|
||||||
|
|
||||||
|
async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
|
||||||
|
"""Return metadata about stored block uids and current load"""
|
||||||
|
|
||||||
async def _rpc_forward(
|
backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))
|
||||||
*flat_tensors: torch.Tensor,
|
result = {
|
||||||
requested_backends: Sequence[TransformerBackend],
|
"version": petals.__version__,
|
||||||
prioritizer: TaskPrioritizerBase,
|
"dht_client_mode": self.dht.client_mode,
|
||||||
points: int = 0,
|
CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()),
|
||||||
) -> torch.Tensor:
|
}
|
||||||
"""
|
|
||||||
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
|
|
||||||
|
|
||||||
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
|
if request.uid:
|
||||||
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
|
block_info = self.module_backends[request.uid].get_info()
|
||||||
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
|
common_keys = set(result.keys()) & set(block_info.keys())
|
||||||
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
|
if common_keys:
|
||||||
"""
|
raise RuntimeError(f"The block's rpc_info has keys reserved for the server's rpc_info: {common_keys}")
|
||||||
hidden_states, prompts = flat_tensors
|
result.update(block_info)
|
||||||
dtype = requested_backends[0].dtype
|
|
||||||
# check parse input tensors and cast dtypes
|
|
||||||
hidden_states = hidden_states.to(dtype)
|
|
||||||
assert hidden_states.ndim == 3
|
|
||||||
if prompts is None or is_dummy(prompts):
|
|
||||||
prompts = [DUMMY] * len(requested_backends)
|
|
||||||
else:
|
|
||||||
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
|
|
||||||
|
|
||||||
# Run a chain of requested backends
|
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
|
||||||
for backend, prompt in zip(requested_backends, prompts):
|
|
||||||
if not is_dummy(prompt):
|
|
||||||
hidden_states[:, : prompt.shape[1]] += prompt
|
|
||||||
|
|
||||||
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
|
|
||||||
priority = prioritizer.prioritize(
|
|
||||||
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
|
|
||||||
)
|
|
||||||
(hidden_states,) = await backend.forward_pool.submit_task(
|
|
||||||
hidden_states,
|
|
||||||
priority=priority,
|
|
||||||
)
|
|
||||||
assert isinstance(hidden_states, torch.Tensor)
|
|
||||||
assert (
|
|
||||||
hidden_states.ndim == 3
|
|
||||||
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
|
||||||
|
|
||||||
# Serialize the overall output
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
async def _rpc_backward(
|
|
||||||
*flat_tensors: torch.Tensor,
|
|
||||||
requested_backends: Sequence[TransformerBackend],
|
|
||||||
prioritizer: TaskPrioritizerBase,
|
|
||||||
points: int = 0,
|
|
||||||
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|
|
||||||
inputs, grad_outputs, prompts = flat_tensors
|
|
||||||
# Cast inputs & grad outputs to backend dtype
|
|
||||||
inputs = inputs.to(requested_backends[0].dtype)
|
|
||||||
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
|
|
||||||
|
|
||||||
if prompts is None or is_dummy(prompts):
|
|
||||||
prompts = [DUMMY] * len(requested_backends)
|
|
||||||
else:
|
|
||||||
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
|
|
||||||
|
|
||||||
# Run a forward chain to collect intermediate inputs
|
|
||||||
# Note that we do not forward for the last module since we do not need its output
|
|
||||||
inter_inputs = []
|
|
||||||
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
|
|
||||||
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
|
|
||||||
if not is_dummy(prompt):
|
|
||||||
inputs[:, : prompt.shape[1]] += prompt
|
|
||||||
inter_inputs.append(inputs)
|
|
||||||
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
|
|
||||||
priority = prioritizer.prioritize(
|
|
||||||
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
|
|
||||||
)
|
|
||||||
(inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
|
|
||||||
|
|
||||||
assert isinstance(inputs, torch.Tensor)
|
|
||||||
|
|
||||||
if not is_dummy(prompts[-1]):
|
|
||||||
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
|
|
||||||
inter_inputs.append(inputs)
|
|
||||||
|
|
||||||
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
|
|
||||||
grad_prompts_reversed = []
|
|
||||||
# Run a chain of requested backends
|
|
||||||
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
|
|
||||||
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
|
|
||||||
priority = prioritizer.prioritize(
|
|
||||||
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
|
|
||||||
)
|
|
||||||
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
|
|
||||||
|
|
||||||
assert isinstance(grad_outputs, torch.Tensor)
|
|
||||||
if not is_dummy(prompt):
|
|
||||||
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
|
|
||||||
|
|
||||||
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
|
|
||||||
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
|
|
||||||
|
@ -10,7 +10,7 @@ import ctypes
|
|||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import AsyncContextManager, Dict, Optional, Union
|
from typing import AsyncContextManager, Dict, Optional, Sequence
|
||||||
|
|
||||||
import hivemind
|
import hivemind
|
||||||
import torch
|
import torch
|
||||||
@ -18,7 +18,7 @@ from hivemind.utils import TensorDescriptor, get_logger
|
|||||||
|
|
||||||
from petals.utils.asyncio import shield_and_wait
|
from petals.utils.asyncio import shield_and_wait
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
Handle = int
|
Handle = int
|
||||||
|
|
||||||
@ -26,11 +26,10 @@ Handle = int
|
|||||||
class MemoryCache:
|
class MemoryCache:
|
||||||
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
|
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
|
||||||
|
|
||||||
def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
|
def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
|
||||||
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
|
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
|
||||||
self.alloc_timeout = alloc_timeout
|
self.alloc_timeout = alloc_timeout
|
||||||
self.device = device
|
self._lock_metadata = mp.Lock()
|
||||||
self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
|
|
||||||
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
||||||
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
|
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
|
||||||
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
|
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
|
||||||
@ -48,6 +47,10 @@ class MemoryCache:
|
|||||||
def current_size_bytes(self, value: int):
|
def current_size_bytes(self, value: int):
|
||||||
self._current_size.value = value
|
self._current_size.value = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bytes_left(self) -> int:
|
||||||
|
return self.max_size_bytes - self.current_size_bytes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def handle_counter(self) -> int:
|
def handle_counter(self) -> int:
|
||||||
return self._handle_counter.value
|
return self._handle_counter.value
|
||||||
@ -57,26 +60,48 @@ class MemoryCache:
|
|||||||
self._handle_counter.value = value
|
self._handle_counter.value = value
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
|
async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
|
||||||
"""
|
"""
|
||||||
Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
Create a handle that is associated with buffers on a unique device. If the cache is full, raises AllocationFailed.
|
Create a handle that is associated with buffers on a unique device. If the cache is full, raises AllocationFailed.
|
Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
|
||||||
|
|
||||||
:param descr: allocate a tensor of this size, dtype, etc
|
:param descriptors: one or more tensor descriptors specifying the size, dtype, and device of each buffer to allocate
|
||||||
|
|
||||||
|
:note: if descriptors reside on different devices, they are expected to be approximately balanced across devices;
|
||||||
|
if not, the maximum per-device allocation is counted towards the size limit
|
||||||
|
|
||||||
:note: This function should be called by connection handlers; it can be called concurrently from multiple processes.
|
:note: This function should be called by connection handlers; it can be called concurrently from multiple processes.
|
||||||
Furthermore, it can be called concurrently with at most one use_cache call in runtime.
|
Furthermore, it can be called concurrently with at most one use_cache call in runtime.
|
||||||
"""
|
"""
|
||||||
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
|
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
|
||||||
assert descr.device is None and descr
|
assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
|
||||||
|
max_alloc_size = self.get_allocation_size(*descriptors)
|
||||||
|
|
||||||
alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
|
gib = 1024**3
|
||||||
alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
|
cur_size, max_size = self.current_size_bytes, self.max_size_bytes
|
||||||
|
friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
|
||||||
|
logger.info(
|
||||||
|
f"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), "
|
||||||
|
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
|
||||||
try:
|
try:
|
||||||
yield await shield_and_wait(alloc_task)
|
handles = await shield_and_wait(alloc_task)
|
||||||
|
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
|
||||||
|
yield handles
|
||||||
finally:
|
finally:
|
||||||
await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
|
self._free(max_alloc_size, alloc_task)
|
||||||
|
|
||||||
async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
|
@staticmethod
|
||||||
|
def get_allocation_size(*descriptors: TensorDescriptor) -> int:
|
||||||
|
"""Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
|
||||||
|
alloc_size_by_device = {}
|
||||||
|
for descr in descriptors:
|
||||||
|
tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
|
||||||
|
alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
|
||||||
|
return max(alloc_size_by_device.values())
|
||||||
|
|
||||||
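get_allocation_size above sums the requested bytes per device and then takes the maximum over devices, so the size limit is enforced against the most heavily loaded device rather than the total. A worked example with invented shapes; a tiny stand-in class replaces TensorDescriptor so only the arithmetic is shown:

import torch

class Descr:
    # Stand-in for TensorDescriptor; only numel(), dtype, and device matter for the size computation.
    def __init__(self, shape, dtype, device):
        self.shape, self.dtype, self.device = shape, dtype, device
    def numel(self):
        n = 1
        for dim in self.shape:
            n *= dim
        return n

def get_allocation_size(*descriptors) -> int:
    alloc_size_by_device = {}
    for descr in descriptors:
        tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
        alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
    return max(alloc_size_by_device.values())

# Two 1 GiB float16 buffers on cuda:0 and one 1.5 GiB float16 buffer on cuda:1 (made-up shapes):
descrs = [
    Descr((512, 1024, 1024), torch.float16, "cuda:0"),  # 1 GiB
    Descr((512, 1024, 1024), torch.float16, "cuda:0"),  # 1 GiB
    Descr((768, 1024, 1024), torch.float16, "cuda:1"),  # 1.5 GiB
]
print(get_allocation_size(*descrs) / 1024**3)  # 2.0 -> the busiest device (cuda:0) sets the limit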
|
async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
|
||||||
"""
|
"""
|
||||||
This method should be called inside asyncio.shield() because:
|
This method should be called inside asyncio.shield() because:
|
||||||
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
|
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
|
||||||
@ -86,26 +111,20 @@ class MemoryCache:
|
|||||||
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
|
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
|
||||||
if self.current_size_bytes + alloc_size > self.max_size_bytes:
|
if self.current_size_bytes + alloc_size > self.max_size_bytes:
|
||||||
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
|
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
|
||||||
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
with self._lock_metadata:
|
||||||
handle = int(self.handle_counter)
|
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
|
||||||
self.current_size_bytes += alloc_size
|
self.current_size_bytes += alloc_size
|
||||||
self.handle_counter += 1 # note: this will eventually overflow and it is okay
|
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
|
||||||
self._pipe_send.send((handle, descr))
|
self._pipe_send.send((handles, descriptors))
|
||||||
return handle
|
return handles
|
||||||
|
|
||||||
async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
|
|
||||||
"""
|
|
||||||
This method should be called inside asyncio.shield() because:
|
|
||||||
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
|
|
||||||
- _schedule_free() must finish freeing memory even in case of cancellation
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
|
||||||
if alloc_task.exception() is not None:
|
if alloc_task.exception() is not None:
|
||||||
return
|
return
|
||||||
handle = alloc_task.result()
|
handles = alloc_task.result()
|
||||||
|
|
||||||
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
with self._lock_metadata:
|
||||||
self._pipe_send.send((handle, None)) # signal runtime to free that handle
|
self._pipe_send.send((handles, None)) # signal runtime to free these handles
|
||||||
self.current_size_bytes -= alloc_size
|
self.current_size_bytes -= alloc_size
|
||||||
self._memory_freed_event.set()
|
self._memory_freed_event.set()
|
||||||
|
|
||||||
@ -125,33 +144,32 @@ class MemoryCache:
|
|||||||
self._memory_freed_event.clear()
|
self._memory_freed_event.clear()
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def use_cache(self, handle: Handle) -> torch.Tensor:
|
def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Return a tensor that was previously allocated with try_allocate_cache,
|
Return one or more tensors previously allocated with allocate_cache.
|
||||||
|
|
||||||
:note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
|
:note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism.
|
||||||
However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
|
However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
|
||||||
"""
|
"""
|
||||||
assert os.getpid() == self.runtime_pid
|
assert os.getpid() == self.runtime_pid
|
||||||
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
|
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
|
||||||
|
|
||||||
with self._lock_metadata:
|
# read creation/deletion requests from connection handlers
|
||||||
# read creation/deletion requests from connection handlers
|
while self._pipe_recv.poll():
|
||||||
while self._pipe_recv.poll():
|
recv_handles, recv_data = self._pipe_recv.recv()
|
||||||
recv_handle, recv_data = self._pipe_recv.recv()
|
if recv_data is not None: # create new tensors
|
||||||
if isinstance(recv_data, TensorDescriptor):
|
assert len(recv_handles) == len(recv_data)
|
||||||
self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
|
for handle, descr in zip(recv_handles, recv_data):
|
||||||
elif recv_data is None:
|
self._allocated_tensors[handle] = descr.make_zeros()
|
||||||
if recv_handle not in self._allocated_tensors:
|
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
|
||||||
|
else: # delete tensors by handle
|
||||||
|
for handle in recv_handles:
|
||||||
|
if handle not in self._allocated_tensors:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
|
f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
|
||||||
)
|
)
|
||||||
self._allocated_tensors.pop(recv_handle, None)
|
self._allocated_tensors.pop(handle, None)
|
||||||
else:
|
yield tuple(self._allocated_tensors[handle] for handle in handles)
|
||||||
logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
|
|
||||||
|
|
||||||
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
|
|
||||||
yield self._allocated_tensors[handle]
|
|
||||||
|
|
||||||
|
|
||||||
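Putting the two halves of MemoryCache together: a connection handler process reserves buffers inside the async allocate_cache context, while the runtime process later resolves the resulting handles into tensors with use_cache. A rough usage sketch under the assumption that `cache` is an already-constructed MemoryCache and `descriptors` are TensorDescriptor objects prepared elsewhere (e.g., by the backend); the function names here are invented for illustration:

from typing import Awaitable, Callable, Sequence

async def handler_side(cache, descriptors, run_session: Callable[[Sequence[int]], Awaitable[None]]):
    # Connection handler process: reserve the buffers for the whole inference session.
    async with cache.allocate_cache(*descriptors) as handles:
        await run_session(handles)  # the buffers stay allocated until this block exits

def runtime_side(cache, handles: Sequence[int]):
    # Runtime process: resolve the handles into the actual tensors and update them in place.
    with cache.use_cache(*handles) as cache_tensors:
        for tensor in cache_tensors:
            tensor.zero_()  # stand-in for writing the new attention keys/values for this step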
class AllocationFailed(Exception):
|
class AllocationFailed(Exception):
|
||||||
|
164
src/petals/server/reachability.py
Normal file
164
src/petals/server/reachability.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent.futures import Future
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from hivemind.dht import DHT, DHTNode
|
||||||
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
||||||
|
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
|
||||||
|
from hivemind.proto import dht_pb2
|
||||||
|
from hivemind.utils import get_logger
|
||||||
|
|
||||||
|
from petals.constants import REACHABILITY_API_URL
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
|
||||||
|
"""verify that your peer is reachable from a (centralized) validator, whether directly or through a relay"""
|
||||||
|
for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
|
||||||
|
try:
|
||||||
|
r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
|
||||||
|
r.raise_for_status()
|
||||||
|
response = r.json()
|
||||||
|
|
||||||
|
if response["success"]:
|
||||||
|
logger.info("Server is reachable from the Internet. It will appear at https://health.petals.dev soon")
|
||||||
|
return
|
||||||
|
|
||||||
|
if attempt_no == 0:
|
||||||
|
# Usually, libp2p manages to set up relays before we finish loading blocks.
|
||||||
|
# In other cases, we may need to wait for up to `wait_time` seconds before it's done.
|
||||||
|
logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Skipping reachability check because health.petals.dev is down: {repr(e)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Server has not become reachable from the Internet:\n\n"
|
||||||
|
f"{response['message']}\n\n"
|
||||||
|
f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
|
||||||
|
f" 1. Choose a specific port for the Petals server, for example, 31337.\n"
|
||||||
|
f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
|
||||||
|
f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
|
||||||
|
f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
|
||||||
|
f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
|
||||||
|
"""test if your peer is accessible by others in the swarm with the specified network options in **kwargs"""
|
||||||
|
|
||||||
|
async def _check_direct_reachability():
|
||||||
|
target_dht = await DHTNode.create(client_mode=True, **kwargs)
|
||||||
|
try:
|
||||||
|
protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p)
|
||||||
|
async with protocol.serve(target_dht.protocol.p2p):
|
||||||
|
successes = requests = 0
|
||||||
|
for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()):
|
||||||
|
probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id)
|
||||||
|
if probe_available is None:
|
||||||
|
continue # remote peer failed to check probe
|
||||||
|
successes += probe_available
|
||||||
|
requests += 1
|
||||||
|
if requests >= max_peers:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.debug(f"Direct reachability: {successes}/{requests}")
|
||||||
|
return (successes / requests) >= threshold if requests > 0 else None
|
||||||
|
finally:
|
||||||
|
await target_dht.shutdown()
|
||||||
|
|
||||||
|
return RemoteExpertWorker.run_coroutine(_check_direct_reachability())
|
||||||
|
|
||||||
|
|
||||||
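check_direct_reachability above asks up to max_peers other peers to probe a temporary libp2p instance and then votes: the fraction of successful probes is compared against the threshold, and None is returned when nobody answered. A tiny worked example of that decision rule, detached from the networking:

from typing import Optional, Sequence

def decide_reachability(probe_results: Sequence[Optional[bool]], threshold: float = 0.5) -> Optional[bool]:
    answered = [r for r in probe_results if r is not None]  # peers that actually responded
    if not answered:
        return None  # no evidence either way
    return sum(answered) / len(answered) >= threshold

print(decide_reachability([True, True, False, None, True]))  # 3/4 = 0.75 >= 0.5 -> True
print(decide_reachability([None, None]))                     # nobody answered -> None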
|
STRIPPED_PROBE_ARGS = dict(
|
||||||
|
dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True, startup_timeout=60
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReachabilityProtocol(ServicerBase):
|
||||||
|
"""Mini protocol to test if a locally running peer is accessible by other devices in the swarm"""
|
||||||
|
|
||||||
|
def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
|
||||||
|
self.probe = probe
|
||||||
|
self.wait_timeout = wait_timeout
|
||||||
|
self._event_loop = self._stop = None
|
||||||
|
|
||||||
|
async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
|
||||||
|
"""Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
|
||||||
|
try:
|
||||||
|
request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
|
||||||
|
timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
|
||||||
|
response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
|
||||||
|
logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
|
||||||
|
return response.available
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
|
||||||
|
"""Help another peer to check its reachability"""
|
||||||
|
response = dht_pb2.PingResponse(available=True)
|
||||||
|
check_peer = PeerID(request.peer.node_id)
|
||||||
|
if check_peer != context.local_id: # remote peer wants us to check someone other than ourselves
|
||||||
|
response.available = await self.call_check(check_peer, check_peer=check_peer) is True
|
||||||
|
logger.info(
|
||||||
|
f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, "
|
||||||
|
f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}"
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def serve(self, p2p: P2P):
|
||||||
|
try:
|
||||||
|
await self.add_p2p_handlers(p2p)
|
||||||
|
yield self
|
||||||
|
finally:
|
||||||
|
await self.remove_p2p_handlers(p2p)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]:
|
||||||
|
protocol = cls(**kwargs)
|
||||||
|
ready = Future()
|
||||||
|
|
||||||
|
async def _serve_with_probe():
|
||||||
|
try:
|
||||||
|
common_p2p = await dht.replicate_p2p()
|
||||||
|
protocol._event_loop = asyncio.get_event_loop()
|
||||||
|
protocol._stop = asyncio.Event()
|
||||||
|
|
||||||
|
initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]
|
||||||
|
for info in await common_p2p.list_peers():
|
||||||
|
initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
|
||||||
|
protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
|
||||||
|
|
||||||
|
ready.set_result(True)
|
||||||
|
logger.info("Reachability service started")
|
||||||
|
|
||||||
|
async with protocol.serve(common_p2p):
|
||||||
|
await protocol._stop.wait()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Reachability service failed:", exc_info=True)
|
||||||
|
|
||||||
|
if not ready.done():
|
||||||
|
ready.set_exception(e)
|
||||||
|
finally:
|
||||||
|
if protocol is not None and protocol.probe is not None:
|
||||||
|
await protocol.probe.shutdown()
|
||||||
|
logger.debug("Reachability service shut down")
|
||||||
|
|
||||||
|
threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
|
||||||
|
if await_ready:
|
||||||
|
ready.result() # Propagates startup exceptions, if any
|
||||||
|
return protocol
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
if self._event_loop is not None and self._stop is not None:
|
||||||
|
self._event_loop.call_soon_threadsafe(self._stop.set)
|
@ -6,33 +6,36 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union

-import numpy as np
-import psutil
-import requests
+import hivemind
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import PretrainedConfig

-from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+import petals
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
-from petals.server.backend import TransformerBackend
-from petals.server.block_utils import get_block_size
+from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
+from petals.server.block_utils import get_block_size, resolve_block_dtype
+from petals.server.from_pretrained import load_pretrained_block
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
-from petals.server.throughput import get_host_throughput
-from petals.utils.convert_8bit import replace_8bit_linear
-from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
+from petals.server.throughput import get_dtype_name, get_server_throughput
+from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.ping import PingAggregator
+from petals.utils.random import sample_up_to
+from petals.utils.version import get_compatible_model_repo

-logger = get_logger(__file__)
+logger = get_logger(__name__)


 class Server:
@ -45,26 +48,28 @@ class Server:
         self,
         *,
         initial_peers: List[str],
-        prefix: Optional[str],
+        dht_prefix: Optional[str],
         converted_model_name_or_path: str,
+        public_name: Optional[str] = None,
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: int = 8,
+        inference_max_length: Optional[int] = None,
         min_batch_size: int = 1,
-        max_batch_size: int = 2048,
-        inference_max_length: int = 2048,
+        max_batch_size: Optional[int] = None,
+        max_chunk_size_bytes: int = 256 * 1024 * 1024,
+        attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
-        revision: str = "main",
+        revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_size: Optional[int] = None,
-        alloc_timeout: float = 60,
+        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
-        update_period: float = 150,
+        update_period: float = 60,
         expiration: Optional[float] = None,
         request_timeout: float = 3 * 60,
         session_timeout: float = 30 * 60,
@ -73,34 +78,44 @@ class Server:
         sender_threads: int = 1,
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 120,
-        mean_block_selection_delay: float = 2.5,
-        use_auth_token: Optional[str] = None,
-        load_in_8bit: Optional[bool] = None,
+        mean_block_selection_delay: float = 5,
+        token: Optional[Union[str, bool]] = None,
+        quant_type: Optional[QuantType] = None,
+        tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        reachable_via_relay: Optional[bool] = None,
+        use_relay: bool = True,
+        use_auto_relay: bool = True,
+        adapters: Sequence[str] = (),
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""

+        converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path)
         self.converted_model_name_or_path = converted_model_name_or_path

         self.num_handlers = num_handlers
-        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
-        self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
-        self.use_auth_token = use_auth_token
+        self.revision, self.token = revision, token

         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)

-        if prefix is None:
-            prefix = converted_model_name_or_path
-            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
-                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
-                f"Please specify --prefix manually when starting a server"
-            )
-            logger.info(f"Automatic dht prefix: {prefix}")
-        self.prefix = prefix
+        self.block_config = AutoDistributedConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=token,
+            revision=revision,
+        )
+
+        if dht_prefix is None:
+            dht_prefix = self.block_config.dht_prefix
+        assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, (
+            f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. "
+            f"Please specify another --dht_prefix manually when starting a server"
+        )
+        self.dht_prefix = dht_prefix

         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
@ -109,75 +124,127 @@ class Server:
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout

-        self.block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path,
-            use_auth_token=use_auth_token,
-            revision=revision,
-        )
-        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+        self.module_uids = [
+            f"{self.dht_prefix}{UID_DELIMITER}{block_index}"
+            for block_index in range(self.block_config.num_hidden_layers)
+        ]

-        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
+        if reachable_via_relay is None:
+            is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
+            reachable_via_relay = is_reachable is False  # if can't check reachability (returns None), run a full peer
+            logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")
+        self.dht = DHT(
+            initial_peers=initial_peers,
+            start=True,
+            num_workers=self.block_config.num_hidden_layers,
+            use_relay=use_relay,
+            use_auto_relay=use_auto_relay,
+            client_mode=reachable_via_relay,
+            **kwargs,
+        )
+        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None
+
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
-            logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            if not skip_reachability_check:
-                self._check_reachability()
+            logger.info("Connecting to the public swarm")
         else:
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+            logger.info(f"Connecting to a private swarm, initial peers: {initial_peers}")
+        logger.info(f"Running a server on {visible_maddrs_str}")
+        self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS

         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         device = torch.device(device)
+        if device.type == "cuda" and device.index is None:
+            device = torch.device(device.type, index=0)
         self.device = device

-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
         self.torch_dtype = torch_dtype

-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        if load_in_8bit:
-            logger.info("Model weights will be loaded in 8-bit format")
-        self.load_in_8bit = load_in_8bit
+        if tensor_parallel_devices is None:
+            tensor_parallel_devices = (device,)
+        self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
+        if len(self.tensor_parallel_devices) > 1:
+            logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
+            check_device_balance(self.tensor_parallel_devices)
+
+        if quant_type is None:
+            if device.type == "cuda":
+                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
+            else:
+                quant_type = QuantType.NONE
+        self.quant_type = quant_type
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
+
+        is_multiquery_attn = self.block_config.num_key_value_groups > 1
+        if max_batch_size is None:
+            max_batch_size = 8192 if is_multiquery_attn else 2048
+        if inference_max_length is None:
+            inference_max_length = 8192 if is_multiquery_attn else 2048
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+        self.max_chunk_size_bytes = max_chunk_size_bytes
+
+        # For attention cache in GPU or RAM
+        if attn_cache_tokens is None:
+            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
+        cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        cache_values_per_block //= self.block_config.num_key_value_groups
+        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+
+        # For disk cache
+        self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
+        self.adapters = adapters

         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
+        if num_blocks is not None:
+            num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
                 first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
             except Exception as e:
-                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
-                raise
+                raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
             block_indices = range(first_block_index, last_block_index)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks

         gib = 1024**3
-        if attn_cache_size is None:
-            # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
-            attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
-        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
-        logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
+        self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
+        logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")

-        if cache_dir is None:
-            cache_dir = DEFAULT_CACHE_DIR
-        self.cache_dir = cache_dir
-        self.max_disk_space = max_disk_space
+        self.alloc_timeout = alloc_timeout

         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(
+            throughput_info = get_server_throughput(
+                converted_model_name_or_path,
                 self.block_config,
                 device,
                 torch_dtype,
-                load_in_8bit=load_in_8bit,
+                num_blocks=num_blocks,
+                quant_type=quant_type,
+                tensor_parallel_devices=self.tensor_parallel_devices,
+                reachable_via_relay=reachable_via_relay,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
-        self.throughput = throughput
+        else:
+            throughput_info = {"throughput": throughput}
+        self.server_info = ServerInfo(
+            state=ServerState.JOINING,
+            public_name=public_name,
+            version=petals.__version__,
+            adapters=tuple(adapters),
+            torch_dtype=str(torch_dtype).replace("torch.", ""),
+            quant_type=quant_type.name.lower(),
+            using_relay=reachable_via_relay,
+            **throughput_info,
+        )

         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
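Reviewer note on the attention-cache sizing added in the hunk above: the per-block cache cost follows directly from the key/value tensor shapes. The sketch below repeats that arithmetic in isolation; the concrete numbers (hidden_size=14336, one key-value group, bfloat16) are illustrative assumptions, not values taken from this commit.

import torch

def cache_bytes_per_block(hidden_size: int, attn_cache_tokens: int, num_key_value_groups: int, dtype: torch.dtype) -> int:
    # Keys and values: two tensors of [attn_cache_tokens, hidden_size] per block,
    # shrunk by the key/value grouping factor used by multi-query attention.
    cache_values = 2 * hidden_size * attn_cache_tokens // num_key_value_groups
    return cache_values * torch.finfo(dtype).bits // 8

print(cache_bytes_per_block(14336, 8192, 1, torch.bfloat16) / 1024**3)  # ~0.44 GiB per block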
@ -185,65 +252,72 @@ class Server:

         self.stop = threading.Event()

-    def _check_reachability(self):
-        try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
-            r.raise_for_status()
-            response = r.json()
-        except Exception as e:
-            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
-            return
-
-        if not response["success"]:
-            # This happens only if health.petals.ml is up and explicitly told us that we are unreachable
-            raise RuntimeError(
-                f"Server is not reachable from the Internet:\n\n"
-                f"{response['message']}\n\n"
-                f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
-                f" 1. Choose a specific port for the Petals server, for example, 31337.\n"
-                f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
-                f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
-                f"    python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
-                f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
-            )
-
-        logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
-
-    def _choose_num_blocks(self) -> int:
-        assert (
-            self.converted_model_name_or_path == "bigscience/bloom-petals"
-        ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
-        assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
-
-        total_memory = torch.cuda.get_device_properties(self.device).total_memory
-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
-        gib = 1024**3
-        attn_cache_per_block = 0.5 * gib  # TODO: This does not account for manually set --attn_cache_size
-
-        num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
+    def _choose_num_blocks(self) -> int:
+        assert self.device.type == "cuda", (
+            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
+            "CPU-only servers in the public swarm are discouraged since they are much slower"
+        )
+        num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
+
+        if num_devices > 1:
+            memory_per_device = tuple(
+                torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
+            )
+            total_memory = min(memory_per_device) * num_devices
+            if max(memory_per_device) / min(memory_per_device) > 1.5:
+                raise ValueError(
+                    "GPU devices have highly uneven memory, which makes tensor parallelism inefficient. "
+                    "Please launch individual servers on each GPU or set --num_blocks manually to "
+                    "override this exception."
+                )
+        else:
+            total_memory = torch.cuda.get_device_properties(self.device).total_memory
+
+        gib = 1024**3
+        # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
+        autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
+
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
+        total_memory_per_block = block_size + self._cache_bytes_per_block
+        if self.adapters:
+            # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
+            from petals.utils.peft import estimate_adapter_memory_per_block
+
+            total_memory_per_block += estimate_adapter_memory_per_block(
+                self.block_config,
+                self.torch_dtype,
+                self.adapters,
+                token=self.token,
+                cache_dir=self.cache_dir,
+                max_disk_space=self.max_disk_space,
+            )
+
+        num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"

+        num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         logger.info(
-            f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
+            f"Server will fill your GPU memory with {num_blocks} transformer blocks. "
            f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
-        return min(num_blocks, self.block_config.n_layer)
+        return num_blocks

     def run(self):
         while True:
             block_indices = self._choose_blocks()
             self.module_container = ModuleContainer.create(
                 dht=self.dht,
-                prefix=self.prefix,
+                dht_prefix=self.dht_prefix,
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
-                attn_cache_size=self.attn_cache_size,
+                attn_cache_bytes=self.attn_cache_bytes,
                 alloc_timeout=self.alloc_timeout,
-                throughput=self.throughput,
+                server_info=self.server_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
                 max_batch_size=self.max_batch_size,
+                max_chunk_size_bytes=self.max_chunk_size_bytes,
                 inference_max_length=self.inference_max_length,
                 torch_dtype=self.torch_dtype,
                 cache_dir=self.cache_dir,
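For readers estimating capacity by hand, the new _choose_num_blocks logic above reduces to the following arithmetic. The inputs here (a 24 GiB GPU, 3 GiB per block, 0.5 GiB of attention cache per block, 2 GiB reserved for rpc_backward) are made-up examples, not defaults of this commit.

import math

def choose_num_blocks(total_memory: int, autograd_memory: int, block_size: int,
                      cache_bytes_per_block: int, num_hidden_layers: int) -> int:
    # Fit as many blocks as possible after reserving memory for backward-pass activations.
    num_blocks = math.floor((total_memory - autograd_memory) / (block_size + cache_bytes_per_block))
    assert num_blocks >= 1, "Not enough GPU memory to serve at least one block"
    return min(num_blocks, num_hidden_layers)

gib = 1024**3
print(choose_num_blocks(24 * gib, 2 * gib, 3 * gib, gib // 2, 70))  # -> 6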
@ -258,8 +332,11 @@ class Server:
                 step_timeout=self.step_timeout,
                 prefetch_batches=self.prefetch_batches,
                 sender_threads=self.sender_threads,
-                use_auth_token=self.use_auth_token,
-                load_in_8bit=self.load_in_8bit,
+                revision=self.revision,
+                token=self.token,
+                quant_type=self.quant_type,
+                tensor_parallel_devices=self.tensor_parallel_devices,
+                should_validate_reachability=self.should_validate_reachability,
                 start=True,
             )
             try:
@ -286,10 +363,6 @@ class Server:
             del self.module_container
             gc.collect()  # In particular, this closes unused file descriptors

-            cur_proc = psutil.Process()
-            num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
-            logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
-
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()

@ -308,19 +381,21 @@ class Server:
         # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
         # this delay decreases the probability of a race condition while choosing the best blocks to serve.
         time.sleep(random.random() * 2 * self.mean_block_selection_delay)
-        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.choose_best_blocks(self.num_blocks, module_infos)

     def _should_choose_other_blocks(self) -> bool:
         if self.strict_block_indices is not None:
             return False

-        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)

     def shutdown(self):
         self.stop.set()

+        if self.reachability_protocol is not None:
+            self.reachability_protocol.shutdown()
         self.dht.shutdown()
         self.dht.join()

@ -334,15 +409,16 @@ class ModuleContainer(threading.Thread):
         cls,
         *,
         dht: DHT,
-        prefix: str,
+        dht_prefix: str,
         converted_model_name_or_path: str,
-        block_config: BloomConfig,
-        attn_cache_size: int,
+        block_config: PretrainedConfig,
+        attn_cache_bytes: int,
         alloc_timeout: float,
-        throughput: float,
+        server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
+        max_chunk_size_bytes: int,
         torch_dtype: torch.dtype,
         cache_dir: str,
         max_disk_space: int,
@ -350,89 +426,99 @@ class ModuleContainer(threading.Thread):
         compression: CompressionType,
         update_period: float,
         expiration: Optional[float],
-        use_auth_token: Optional[str],
-        load_in_8bit: bool,
+        revision: Optional[str],
+        token: Optional[Union[str, bool]],
+        quant_type: QuantType,
+        tensor_parallel_devices: Sequence[torch.device],
+        should_validate_reachability: bool,
         **kwargs,
     ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
-        joining_announcer = ModuleAnnouncerThread(
+        module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+
+        server_info.state = ServerState.JOINING
+        dht_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
-            ServerState.JOINING,
-            throughput=throughput,
+            server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
         )
-        joining_announcer.start()
+        dht_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")

-        memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+        assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)

         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
                 block = load_pretrained_block(
                     converted_model_name_or_path,
                     block_index,
-                    block_config,
+                    config=block_config,
                     torch_dtype=torch_dtype,
-                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    token=token,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
+                )
+                block = convert_block(
+                    block,
+                    block_index,
+                    block_config,
+                    tensor_parallel_devices,
+                    device,
+                    quant_type,
+                    adapters=server_info.adapters,
+                    freeze=True,
+                    token=token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )

-                if load_in_8bit:
-                    block = replace_8bit_linear(block)
-
-                block = block.to(device)
-                for param in block.parameters():
-                    param.requires_grad = False
-
-                backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
+                    config=block_config,
                     memory_cache=memory_cache,
-                    backend_dtype=backend_dtype,
+                    backend_dtype=torch_dtype,
+                    max_chunk_size_bytes=max_chunk_size_bytes,
                     args_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     kwargs_schema={},
                     outputs_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
+
+            merge_inference_pools_inplace(blocks)
+
+            if should_validate_reachability:
+                validate_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
                 backend.shutdown()

-            joining_announcer.stop.set()
-            joining_announcer.join()
-            declare_active_modules(
-                dht,
-                module_uids,
-                expiration_time=get_dht_time() + expiration,
-                state=ServerState.OFFLINE,
-                throughput=throughput,
-            )
+            dht_announcer.announce(ServerState.OFFLINE)
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
-        else:
-            joining_announcer.stop.set()
-            joining_announcer.join()

         return cls(
             dht,
+            dht_prefix,
             blocks,
-            throughput=throughput,
-            device=device,
+            dht_announcer=dht_announcer,
+            server_info=server_info,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@ -441,11 +527,13 @@ class ModuleContainer(threading.Thread):
     def __init__(
         self,
         dht: DHT,
+        dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
         inference_max_length: int,
         num_handlers: int,
-        throughput: float,
+        dht_announcer: ModuleAnnouncerThread,
+        server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@ -457,29 +545,31 @@ class ModuleContainer(threading.Thread):
         super().__init__()

         self.dht, self.module_backends = dht, module_backends
-        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.server_info, self.update_period, self.expiration = server_info, update_period, expiration

+        handler_event_queues = [mp.Queue() for _ in range(num_handlers)]
         self.conn_handlers = [
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                adapters=server_info.adapters,
+                dht_prefix=dht_prefix,
+                handler_event_queues=handler_event_queues,
+                handler_index=i,
                 inference_max_length=inference_max_length,
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
                 step_timeout=step_timeout,
+                quant_type=QuantType[server_info.quant_type.upper()],
             )
-            for _ in range(num_handlers)
+            for i in range(num_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, **kwargs)
-        self.online_announcer = ModuleAnnouncerThread(
-            list(self.module_backends.keys()),
-            dht,
-            ServerState.ONLINE,
-            throughput=throughput,
-            update_period=update_period,
-            expiration=expiration,
-            daemon=True,
-        )
-        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
+
+        self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
+        # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
+
+        dht_announcer.announce(ServerState.ONLINE)
+        self.dht_announcer = dht_announcer

         if start:
             self.run_in_background(await_ready=True)
@ -489,14 +579,6 @@ class ModuleContainer(threading.Thread):
         Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        if not self.dht.is_alive():
-            self.dht.run_in_background(await_ready=True)
-
-        self.online_announcer.start()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
         for handler in self.conn_handlers:
             handler.run_in_background()

@ -535,27 +617,14 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
-        self.online_announcer.stop.set()
-        self.online_announcer.join()
-
-        declare_active_modules(
-            self.dht,
-            self.module_backends.keys(),
-            expiration_time=get_dht_time() + self.expiration,
-            state=ServerState.OFFLINE,
-            throughput=self.throughput,
-        )
+        self.dht_announcer.announce(ServerState.OFFLINE)
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

         self.ready.clear()

+        logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
-        logger.debug("Connection handlers terminated")
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()

         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools:
@ -579,30 +648,85 @@ class ModuleAnnouncerThread(threading.Thread):
         self,
         module_uids: List[str],
         dht: DHT,
-        state: ServerState,
+        server_info: ServerInfo,
         *,
-        throughput: float,
-        update_period: float = 30,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
+        update_period: float,
         expiration: float,
+        max_pinged: int = 5,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.module_uids = module_uids
         self.dht = dht
-        self.state = state
-        self.throughput = throughput
+        self.server_info = server_info
+        self.memory_cache = memory_cache
+
+        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token //= block_config.num_key_value_groups
+
         self.update_period = update_period
         self.expiration = expiration
-        self.stop = threading.Event()
+        self.trigger = threading.Event()
+
+        self.max_pinged = max_pinged
+        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
+        start_block, end_block = min(block_indices), max(block_indices) + 1
+        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.ping_aggregator = PingAggregator(self.dht)

     def run(self) -> None:
         while True:
+            start_time = time.perf_counter()
+
+            self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
+            if self.server_info.state != ServerState.OFFLINE:
+                self._ping_next_servers()
+                self.server_info.next_pings = {
+                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items()
+                }
+            else:
+                self.server_info.next_pings = None  # No need to ping if we're disconnecting
+
             declare_active_modules(
                 self.dht,
                 self.module_uids,
+                self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
-                state=self.state,
-                throughput=self.throughput,
             )
-            if self.stop.wait(self.update_period):
+            if self.server_info.state == ServerState.OFFLINE:
                 break
+
+            delay = self.update_period - (time.perf_counter() - start_time)
+            if delay < 0:
+                logger.warning(
+                    f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
+                )
+            self.trigger.wait(max(delay, 0))
+            self.trigger.clear()
+
+    def announce(self, state: ServerState) -> None:
+        self.server_info.state = state
+        self.trigger.set()
+        if state == ServerState.OFFLINE:
+            self.join()
+
+    def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
+        module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
+        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
+        pinged_servers.discard(self.dht.peer_id)
+        if module_infos[-1] is not None:
+            # Sample servers hosting the block after the last one (most likely continuations) separately
+            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        self.ping_aggregator.ping(list(pinged_servers))
+
+
+class RuntimeWithDeduplicatedPools(Runtime):
+    """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pools = tuple(set(self.pools))
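The rewritten announcer loop above keeps a roughly fixed cadence: it subtracts the time spent declaring modules from update_period and then waits on a trigger event so that announce() can wake it early. A stripped-down sketch of that scheduling pattern follows; the function and variable names here are illustrative, not part of the commit.

import threading
import time

def run_periodically(step, update_period: float, trigger: threading.Event) -> None:
    # Call step() every update_period seconds; trigger.set() wakes the loop early.
    while True:
        start_time = time.perf_counter()
        if step() is False:  # a step may return False to stop the loop
            break
        delay = update_period - (time.perf_counter() - start_time)
        if delay < 0:
            print(f"step() took longer than update_period={update_period}")
        trigger.wait(max(delay, 0))
        trigger.clear()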
@ -5,14 +5,14 @@ import time
 from concurrent.futures._base import PENDING
 from dataclasses import dataclass, field
 from queue import PriorityQueue
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple, Union

 import torch
 from hivemind import get_logger
 from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture

-logger = get_logger(__file__)
+logger = get_logger(__name__)


 @dataclass(order=True, frozen=True)
@ -43,6 +43,7 @@ class PrioritizedTaskPool(TaskPoolBase):

     :param name: pool name, used for logging
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param device: if specified, input tensors will be moved to that device by default
     :param start: if True, start automatically at the end of __init__
     """

@ -52,11 +53,13 @@ class PrioritizedTaskPool(TaskPoolBase):
         max_batch_size: int,
         name: str,
         min_batch_size=1,
+        device: Optional[torch.device] = None,
         daemon=True,
         start=False,
     ):
         super().__init__(process_func, daemon=daemon, name=name)
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.device = device

         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
@ -101,7 +104,7 @@ class PrioritizedTaskPool(TaskPoolBase):
             logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
             self.terminate()

-    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
+    def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         future = MPFuture()
         # Remove shmem from MPFuture. This disables the .cancel() feature but
@ -129,10 +132,9 @@ class PrioritizedTaskPool(TaskPoolBase):
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
     ) -> Tuple[Any, List[torch.Tensor]]:
         """receive next batch of arrays"""
+        device = device if device is not None else self.device
         task = self._ordered_tasks.get(block=True, timeout=timeout)
-        batch_inputs = [
-            tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
-        ]
+        batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
         self._dispatched_tasks[task.uid] = task
         self.batch_receiver.recv()  # reduce the number of active batches
         if not self._ordered_tasks.empty():
@ -142,11 +144,7 @@ class PrioritizedTaskPool(TaskPoolBase):

     def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
-        batch_outputs = [
-            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
-            for tensor in batch_outputs
-        ]
+        batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]

         task = self._dispatched_tasks.pop(uid, None)
         if task is None:
             logger.error(
@ -182,3 +180,13 @@ class PrioritizedTaskPool(TaskPoolBase):
         assert len(item) == 2
         self._priority.value = float(item[0])
         self._oldest_undispatched_timestamp.value = float(item[1])
+
+
+def _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str], share_memory: bool = False):
+    if isinstance(arg, torch.Tensor):
+        arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad)
+        # note: it is important that non_blocking is disabled if share_memory=True; using share_memory on a tensor
+        # produced by a non-blocking copy will result in undefined behavior (depending on your gpu speed)
+        if share_memory:
+            arg = arg.share_memory_()
+    return arg
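A brief usage sketch for the _move_to_device_if_tensor helper added above, showing that non-tensor arguments pass through untouched. The helper body mirrors the diff; the batch contents are made up for illustration.

import torch

def _move_to_device_if_tensor(arg, device, share_memory: bool = False):
    if isinstance(arg, torch.Tensor):
        # non_blocking must be off when sharing memory, so the copy completes before share_memory_()
        arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad)
        if share_memory:
            arg = arg.share_memory_()
    return arg

batch = [torch.randn(2, 3), {"prompt_len": 7}, torch.ones(4, requires_grad=True)]
moved = [_move_to_device_if_tensor(x, device="cpu", share_memory=True) for x in batch]
print([type(x).__name__ for x in moved])  # -> ['Tensor', 'dict', 'Tensor']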
@ -13,7 +13,10 @@ class TaskPrioritizerBase(ABC):


 class DummyTaskPrioritizer(TaskPrioritizerBase):
-    """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""

     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
-        return 0.0
+        # Inference steps (especially short ones) go first since they are more latency-sensitive
+        if kwargs.get("type") == "short_inference":
+            return 1.0
+        if kwargs.get("type") == "inference":
+            return 2.0
+        return 3.0  # Forward, backward
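To make the new priorities concrete: lower values are dispatched first, so short inference steps preempt regular inference, which preempts forward/backward. A toy illustration follows (the priority values are copied from the change above; the queue itself is hypothetical):

import heapq

PRIORITY = {"short_inference": 1.0, "inference": 2.0, "forward": 3.0, "backward": 3.0}

queue = []
for seq, task_type in enumerate(["backward", "inference", "short_inference", "forward"]):
    heapq.heappush(queue, (PRIORITY[task_type], seq, task_type))  # seq breaks ties in FIFO order

print([heapq.heappop(queue)[2] for _ in range(len(queue))])
# -> ['short_inference', 'inference', 'backward', 'forward']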
@ -1,21 +1,22 @@
 import fcntl
 import json
+import math
+import multiprocessing as mp
 import os
 import time
-from hashlib import sha256
+from collections import Counter
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Sequence, Union

 import torch
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import PretrainedConfig

-from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR

-logger = get_logger(__file__)
+logger = get_logger(__name__)

 try:
     import speedtest
@ -31,21 +32,26 @@ if not hasattr(speedtest, "Speedtest"):
     )


-def get_host_throughput(
-    config: BloomConfig,
+def get_server_throughput(
+    model_name: str,
+    config: PretrainedConfig,
     device: torch.device,
     dtype: Union[str, torch.dtype],
     *,
-    load_in_8bit: bool,
+    num_blocks: int,
+    quant_type: QuantType,
+    tensor_parallel_devices: Sequence[torch.device],
+    reachable_via_relay: bool,
+    relay_penalty: float = 0.2,
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
-) -> float:
+) -> Dict[str, float]:
     dtype = resolve_block_dtype(config, dtype)

     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v2.json")
+    cache_path = Path(cache_dir, "throughput_v5.json")

     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@ -54,9 +60,12 @@ def get_host_throughput(
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed

-        cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
+        cache_key = f"model_{model_name}"
         cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
-        cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
+        cache_key += f"_dtype_{get_dtype_name(dtype, quant_type)}"
+        if len(tensor_parallel_devices) > 1:
+            for i, device_i in enumerate(tensor_parallel_devices):
+                cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"

         cache = {}
         try:
@ -69,7 +78,9 @@ def get_host_throughput(
                 cache = {}

         if cache_key not in cache:
-            cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
+            cache[cache_key] = measure_throughput_info(
+                config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
+            )

             try:
                 os.makedirs(cache_path.parent, exist_ok=True)
@ -78,80 +89,143 @@ def get_host_throughput(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Failed to save throughput info in {cache_path}")
|
logger.exception(f"Failed to save throughput info in {cache_path}")
|
||||||
|
|
||||||
return cache[cache_key]
|
throughput_info = cache[cache_key]
|
||||||
|
|
||||||
|
# Most requests start at some block hosted by a server, then use all next blocks hosted on this server.
|
||||||
|
# Assuming the start block index is distributed uniformly, the average number of blocks used per request is
|
||||||
|
# E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
|
||||||
|
average_blocks_used = (num_blocks + 1) / 2
|
||||||
|
throughput = throughput_info["forward_rps"] / average_blocks_used
|
||||||
|
|
||||||
|
network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1)
|
||||||
|
throughput = min(throughput, network_rps)
|
||||||
|
|
||||||
|
throughput_info["throughput"] = throughput
|
||||||
|
logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")
|
||||||
|
|
||||||
|
return throughput_info
|
||||||
|
|
||||||
|
|
||||||
def measure_throughput_info(
|
def measure_throughput_info(
|
||||||
config: BloomConfig,
|
config: PretrainedConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
*,
|
*,
|
||||||
load_in_8bit: bool,
|
quant_type: QuantType,
|
||||||
) -> float:
|
tensor_parallel_devices: Sequence[torch.device],
|
||||||
"""Measure network and compute throughput in forward pass tokens per second"""
|
) -> Dict[str, float]:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
|
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
|
||||||
)
|
)
|
||||||
return min(
|
return {
|
||||||
measure_network_rps(config),
|
"inference_rps": measure_compute_rps(
|
||||||
measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit),
|
config,
|
||||||
)
|
device,
|
||||||
|
dtype,
|
||||||
|
quant_type=quant_type,
|
||||||
|
tensor_parallel_devices=tensor_parallel_devices,
|
||||||
|
n_tokens=1,
|
||||||
|
n_steps=100,
|
||||||
|
inference=True,
|
||||||
|
),
|
||||||
|
"forward_rps": measure_compute_rps(
|
||||||
|
config,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
quant_type=quant_type,
|
||||||
|
tensor_parallel_devices=tensor_parallel_devices,
|
||||||
|
n_tokens=1024,
|
||||||
|
n_steps=10,
|
||||||
|
inference=False,
|
||||||
|
),
|
||||||
|
"network_rps": measure_network_rps(config),
|
||||||
|
}
|
||||||
|
|
-def measure_network_rps(config: BloomConfig) -> float:
+def measure_network_rps(
+    config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6  # 100 Mbit/s
+) -> Optional[float]:
+    bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
+    try:
+        pipe_recv, pipe_send = mp.Pipe(duplex=False)
+        process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
+        process.start()
+
+        if not pipe_recv.poll(timeout):
+            process.terminate()
+            raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
+        network_info = pipe_recv.recv()
+        if "exception" in network_info:
+            raise RuntimeError(f"speedtest failed: {network_info['exception']}")
+
+        network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
+        if network_rps == 0:
+            raise RuntimeError("speedtest has returned network_rps == 0")
+
+        logger.info(
+            f"Network throughput: {network_rps:.1f} tokens/sec "
+            f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
+            f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
+        )
+        return network_rps
+    except RuntimeError as e:
+        logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s")
+        return default_speed / bits_per_request
+
+
+def _measure_bits_per_second(pipe_send: mp.Pipe):
     try:
         s = speedtest.Speedtest()
         s.get_servers()
         s.get_best_server()
         s.download()
         s.upload()
-        network_info = s.results.dict()
-    except:
-        logger.error("Failed to measure network throughput:")
-        raise
-
-    bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
-    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
-
-    logger.info(
-        f"Network throughput: "
-        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
-        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
-        f"{network_rps:.1f} RPS"
-    )
-    return network_rps
+        pipe_send.send(s.results.dict())
+    except Exception as e:
+        pipe_send.send({"exception": repr(e)})
 def measure_compute_rps(
-    config: BloomConfig,
+    config: PretrainedConfig,
     device: torch.device,
     dtype: torch.dtype,
     *,
-    load_in_8bit: bool,
-    n_tokens: int = 16,
-    n_steps: int = 500,
+    quant_type: QuantType,
+    tensor_parallel_devices: Sequence[torch.device],
+    n_tokens: int,
+    n_steps: int,
+    inference: bool,
 ) -> float:
+    device = torch.device(device)
+    if not tensor_parallel_devices:
+        tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = WrappedBloomBlock(config).to(dtype)
-        if load_in_8bit:
-            block = replace_8bit_linear(block)
-        block = block.to(device)
+        block = config.block_class(config).to(dtype)
+        block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

         cache = None
         elapsed = 0
-        for step in range(n_steps + 1):
-            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+        dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
+        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+        if device.type == "cuda":
+            torch.cuda.synchronize(device)

-            start_time = time.perf_counter()
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache)
-            if step >= 1:  # Skip the 1st step to exclude the initialization time
-                elapsed += time.perf_counter() - start_time
+        start_time = time.perf_counter()
+        for step in range(n_steps):
+            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+        if device.type == "cuda":
+            torch.cuda.synchronize(device)
+        elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed

+    devices_repr = get_device_name(device)
+    if len(tensor_parallel_devices) > 1:
+        device_names = tuple(map(get_device_name, map(torch.device, tensor_parallel_devices)))
+        devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
+
     logger.info(
-        f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
-        f"{device_rps:.1f} RPS"
+        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block "
+        f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
     )
     return device_rps
@@ -160,5 +234,8 @@ def get_device_name(device: torch.device) -> str:
     return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"


-def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
-    return "8-bit" if load_in_8bit else str(dtype)
+def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
+    name = str(dtype).replace("torch.", "")
+    if quant_type != QuantType.NONE:
+        name += f", quantized to {quant_type.name.lower()}"
+    return name
@@ -0,0 +1,6 @@
+from petals.utils.auto_config import (
+    AutoDistributedConfig,
+    AutoDistributedModel,
+    AutoDistributedModelForCausalLM,
+    AutoDistributedModelForSequenceClassification,
+)

65  src/petals/utils/auto_config.py  Normal file
@@ -0,0 +1,65 @@
+import os
+from dataclasses import dataclass
+from typing import Optional, Type, Union
+
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
+
+from petals.utils.hf_auth import always_needs_auth
+
+
+@dataclass
+class _ModelClasses:
+    config: Type[PretrainedConfig]
+    model: Optional[Type[PreTrainedModel]] = None
+    model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
+    model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
+
+
+_CLASS_MAPPING = {}  # Populated by petals.models.* subpackages with register_model_classes()
+
+
+def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
+    assert issubclass(config, PretrainedConfig)
+    assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered"
+
+    _CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs)
+
+
+class _AutoDistributedBase:
+    _mapping_field = None  # Should be defined in child classes
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
+        if (
+            always_needs_auth(model_name_or_path)
+            and kwargs.get("token") is None
+            and kwargs.get("use_auth_token") is None
+        ):
+            kwargs["use_auth_token"] = True
+
+        config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
+        if config.model_type not in _CLASS_MAPPING:
+            raise ValueError(f"Petals does not support model type {config.model_type}")
+
+        proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)
+        if proper_cls is None:
+            raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
+
+        return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
+
+
+class AutoDistributedConfig(_AutoDistributedBase):
+    _mapping_field = "config"
+
+
+class AutoDistributedModel(_AutoDistributedBase):
+    _mapping_field = "model"
+
+
+class AutoDistributedModelForCausalLM(_AutoDistributedBase):
+    _mapping_field = "model_for_causal_lm"
+
+
+class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
+    _mapping_field = "model_for_sequence_classification"
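A minimal usage sketch of these auto classes; the model name is only an example, and the call assumes the corresponding petals.models subpackage has already called register_model_classes() for that model type:

from petals.utils.auto_config import AutoDistributedConfig

# Example repo name; any model type registered via register_model_classes() would work here.
config = AutoDistributedConfig.from_pretrained("bigscience/bloom-560m")
print(config.model_type)  # dispatches to the distributed config class registered for "bloom"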
@@ -1,39 +0,0 @@
-import bitsandbytes as bnb
-import torch
-
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
-
-
-def replace_8bit_linear(model, threshold=6.0):
-    """
-    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
-    be kept as a `torch.nn.Linear` module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        threshold (`float`, *optional*):
-            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
-            `6.0` as described by the paper.
-    """
-    for n, module in model.named_children():
-        if len(list(module.children())) > 0:
-            replace_8bit_linear(module, threshold)
-
-        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
-            model._modules[n] = CustomLinear8bitLt(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                has_fp16_weights=False,
-                threshold=threshold,
-            )
-            model._modules[n].weight = bnb.nn.Int8Params(
-                module.weight.data, requires_grad=False, has_fp16_weights=False
-            ).to(module.weight.dtype)
-            model._modules[n].bias = module.bias
-    return model

156  src/petals/utils/convert_block.py  Normal file
@@ -0,0 +1,156 @@
+"""
+Tools for converting transformer blocks, applying quantization and/or tensor parallelism
+"""
+import re
+from enum import Enum
+from typing import Optional, Sequence
+
+import tensor_parallel as tp
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from tensor_parallel.slicing_configs import get_bloom_config
+from transformers import PretrainedConfig
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+class QuantType(Enum):
+    NONE = 0
+    INT8 = 1  # 8-bit as in the LLM.int8() paper
+    NF4 = 2  # 4-bit as in the QLoRA paper
+
+
+def convert_block(
+    block: nn.Module,
+    block_index: int,
+    config: PretrainedConfig,
+    tensor_parallel_devices: Sequence[torch.device],
+    output_device: torch.device,
+    quant_type: QuantType,
+    freeze: bool = True,
+    adapters: Optional[Sequence[str]] = None,
+    **kwargs,
+) -> tp.TensorParallel:
+    """
+    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
+
+    :note: some optimizations will modify the input block in-place!
+    :param block: a single transformer block, either pre-trained or newly initialized
+    :param config: HF transformers config for the full model
+    :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
+    :note: if there is only a single device, the model will still be wrapped with TensorParallel (for uniformity)
+    :param output_device: if tensor_parallel_devices is True, output
+    :param quant_type: quantization type
+    :param freeze: if True (default), make all module parameters non-trainable
+    :return: a module that acts like the original block, but runs with all specified optimizations
+
+    """
+    if freeze:
+        block.requires_grad_(False)
+
+    block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
+
+    if quant_type != QuantType.NONE:
+        block = quantize_module(block, quant_type=quant_type)
+
+    for shard, device in zip(block.module_shards, block.devices):
+        shard.to(device)
+
+    if adapters:
+        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
+        create_lora_adapter(block, quant_type=quant_type)
+        for adapter_name in adapters:
+            adapter_config, adapter_state_dict = load_peft(
+                adapter_name,
+                block_idx=block_index,
+                **kwargs,
+            )
+            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
+
+    return block
+
+
+def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:
+    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
+    import bitsandbytes as bnb
+
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            quantize_module(module, quant_type=quant_type)
+
+        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
+            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
+            if quant_type == QuantType.INT8:
+                model._modules[n] = bnb.nn.Linear8bitLt(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    has_fp16_weights=False,
+                    threshold=6.0,  # Default from the LLM.int8() paper
+                )
+                model._modules[n].weight = bnb.nn.Int8Params(
+                    module.weight.data, requires_grad=False, has_fp16_weights=False
+                ).to(module.weight.dtype)
+            elif quant_type == QuantType.NF4:
+                compress_statistics = True
+                model._modules[n] = bnb.nn.LinearNF4(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    compress_statistics=compress_statistics,
+                )
+                model._modules[n].weight = bnb.nn.Params4bit(
+                    module.weight.data,
+                    requires_grad=False,
+                    quant_type="nf4",
+                    blocksize=64,
+                    compress_statistics=compress_statistics,
+                ).to(module.weight.dtype)
+            else:
+                raise ValueError(f"Unsupported quant_type='{quant_type}'")
+            model._modules[n].bias = module.bias
+    return model
+
+
+def make_tensor_parallel(
+    block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device
+) -> nn.Module:
+    if model_config.model_type == "bloom":
+        tp_config = get_bloom_config(model_config, devices)
+        del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    else:
+        if len(devices) > 1:
+            logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
+        tp_config = None
+    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
+    total_heads = 0
+    for tp_shard in tp_block.module_shards:
+        for submodule in tp_shard.modules():
+            if isinstance(submodule, model_config.attn_class):
+                total_heads += submodule.num_heads
+    assert total_heads == model_config.num_attention_heads
+    return tp_block
+
+
+def check_device_balance(devices: Sequence[torch.device]):
+    if not all(device.type == "cuda" for device in devices):
+        logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
+        return
+    unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
+    if len(unique_device_capabilities) > 1:
+        logger.warning(
+            f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
+            f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
+        )
+
+    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
+    used_memory = min(memory_per_device) * len(memory_per_device)
+    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
+    if wasted_memory_rate > 0.05:
+        logger.warning(
+            f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
+            f"Consider running high-memory GPUs in a separate server."
+        )
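A hedged sketch of how a server might call convert_block on a freshly created block; it assumes the Petals config object exposes block_class, as used in measure_compute_rps above, and picks dtype and quantization values only for illustration:

import torch
from petals.utils.convert_block import QuantType, convert_block

def prepare_block(config, device=torch.device("cuda:0")):
    # `config` is assumed to be a Petals model config exposing .block_class / .attn_class as above.
    block = config.block_class(config).to(torch.bfloat16)
    # With a single device, the block is still wrapped into TensorParallel for uniformity.
    return convert_block(block, 0, config, (device,), device, quant_type=QuantType.NF4, freeze=True)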
@@ -8,7 +8,7 @@ from typing import Optional
 import huggingface_hub
 from hivemind.utils.logging import get_logger

-logger = get_logger(__file__)
+logger = get_logger(__name__)

 DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))

@@ -33,15 +33,12 @@ def allow_cache_reads(cache_dir: Optional[str]):
     return _blocks_lock(cache_dir, fcntl.LOCK_SH)


-def allow_cache_writes(
-    cache_dir: Optional[str], *, reserve: Optional[int] = None, max_disk_space: Optional[int] = None
-):
+def allow_cache_writes(cache_dir: Optional[str]):
     """Allows saving new blocks and removing the old ones (exclusive lock)"""
     return _blocks_lock(cache_dir, fcntl.LOCK_EX)


 def free_disk_space_for(
-    model_name: str,
     size: int,
     *,
     cache_dir: Optional[str],
@@ -51,36 +48,36 @@ def free_disk_space_for(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     cache_info = huggingface_hub.scan_cache_dir(cache_dir)
-    model_repos = [repo for repo in cache_info.repos if repo.repo_type == "model" and repo.repo_id == model_name]
-
-    occupied_space = sum(repo.size_on_disk for repo in model_repos)
     available_space = shutil.disk_usage(cache_dir).free - os_quota
     if max_disk_space is not None:
-        available_space = min(available_space, max_disk_space - occupied_space)
+        available_space = min(available_space, max_disk_space - cache_info.size_on_disk)
+
+    gib = 1024**3
+    logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB")
     if size <= available_space:
         return

-    revisions = [revision for repo in model_repos for revision in repo.revisions]
-    revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified))
+    cached_files = [file for repo in cache_info.repos for revision in repo.revisions for file in revision.files]

-    # Remove as few least recently used blocks as possible
-    pending_removal = []
+    # Remove as few least recently used files as possible
+    removed_files = []
     freed_space = 0
     extra_space_needed = size - available_space
-    for rev in revisions:
-        pending_removal.append(rev.commit_hash)
-        freed_space += rev.size_on_disk
+    for file in sorted(cached_files, key=lambda file: file.blob_last_accessed):
+        os.remove(file.file_path)  # Remove symlink
+        os.remove(file.blob_path)  # Remove contents
+
+        removed_files.append(file)
+        freed_space += file.size_on_disk
         if freed_space >= extra_space_needed:
             break

-    if pending_removal:
-        gib = 1024**3
-        logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space")
-        delete_strategy = cache_info.delete_revisions(*pending_removal)
-        delete_strategy.execute()
+    if removed_files:
+        logger.info(f"Removed {len(removed_files)} files to free {freed_space / gib:.1f} GiB of disk space")
+        logger.debug(f"Removed paths: {[str(file.file_path) for file in removed_files]}")

     if freed_space < extra_space_needed:
         raise RuntimeError(
-            f"Insufficient disk space to load a block. Please free {extra_space_needed - freed_space:.1f} GiB "
+            f"Insufficient disk space to load a block. Please free {(extra_space_needed - freed_space) / gib:.1f} GiB "
             f"on the volume for {cache_dir} or increase --max_disk_space if you set it manually"
         )
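For illustration, a standalone sketch of the same least-recently-used eviction idea over plain files, independent of huggingface_hub, to make the freed_space vs. extra_space_needed loop easier to follow (paths and sizes here are placeholders):

import os
from pathlib import Path

def free_space(cache_dir: Path, extra_space_needed: int) -> int:
    # Delete least recently accessed files first, stopping as soon as enough space is freed.
    files = sorted((p for p in cache_dir.rglob("*") if p.is_file()), key=lambda p: p.stat().st_atime)
    freed = 0
    for path in files:
        freed += path.stat().st_size
        os.remove(path)
        if freed >= extra_space_needed:
            break
    return freed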
@@ -16,7 +16,7 @@ class DecodingAlgorithm(ABC):
     @abstractmethod
     def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
-        :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
+        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
         :return: A tuple of selected token ids and corresponding hypotheses.
             The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
         """
@@ -99,7 +99,6 @@ class RepetitionPenaltyAlgorithm(SamplingAlgorithm):
 class BeamSearchAlgorithm(DecodingAlgorithm):
     def __init__(self, num_beams: int, batch_size: int) -> None:
         self.num_beams = num_beams
-        self._cur_num_beams = 1
         self.batch_size = batch_size

         self._batch_beams = [list() for _ in range(batch_size)]
7  src/petals/utils/hf_auth.py  Normal file
@@ -0,0 +1,7 @@
+import os
+from typing import Union
+
+
+def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
+    loading_from_repo = model_name is not None and not os.path.isdir(model_name)
+    return loading_from_repo and model_name.startswith("meta-llama/Llama-2-")
@@ -1,334 +0,0 @@
-"""
-A patch to bitsandbytes 0.34.0 that introduces an option to run backward pass in default (fast) matrix layout.
-Authors: modification by @borzunov, original code by @timdettmers. Please disregard commit authors in this file.
-
-Core idea: layouts apply the same permutation to every tile in the matrix. We can treat this as (batched) gather ops.
-Reshape input tensor so that ij-th gather operation op will apply to ij-th elements in each tile.
-Prototype: https://colab.research.google.com/drive/1EJ0MKifajXSSVq7O2_QGwtb0l6gRAGrh?usp=sharing
-Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#L2130-L2136
-Exact match tests: see $REPO/tests/test_linear8bitlt.py
-"""
-import dataclasses
-import logging
-from typing import Optional, Tuple
-
-import bitsandbytes.functional as F
-import torch
-from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatMul8bitLt, MatmulLtState, prod
-from bitsandbytes.nn import Linear8bitLt
-
-
-def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
-    """
-    Compute a permutation of indices that invert the specified (tiled) matrix transformation
-
-    :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
-    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
-    :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
-    :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
-    :returns: indices
-    """
-    d1, d2 = tile_size
-    assert 0 < d1 * d2 < 2**64
-    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
-    # encode each position in tile as a tuple of <= 8 unique bytes
-    permuted_tile_indices = torch.zeros_like(tile_indices)
-    for i in range(8):
-        # select i-th byte, apply transformation and trace where each index ended up
-        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
-        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
-        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
-        permuted_tile_i = transform_tile(sample_tile_i)
-        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
-        permuted_tile_indices += ith_permuted_indices * (256**i)
-        if d1 * d2 < 256**i:
-            break  # if all indices fit in i bytes, stop early
-    return permuted_tile_indices
-
-
-def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
-    """
-    Undo a tiled permutation such as turing or ampere layout
-
-    :param permuted_tensor: torch tensor in a permuted layout
-    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
-    :return: contiguous row-major tensor
-    """
-    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
-    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
-    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
-    outputs = torch.empty_like(tensor)  # note: not using .index_copy because it was slower on cuda
-    outputs[tile_indices.flatten()] = tensor
-    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
-    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
-    return outputs.reshape(rows, cols).contiguous()
-
-
-# the rest of this file is just a patch to bitsandbytes that modifies Linear8bitLt and dependencies
-
-
-class CustomLinear8bitLt(Linear8bitLt):
-    def __init__(self, *args, memory_efficient_backward: bool = False, **kwargs):
-        assert not memory_efficient_backward, "memory_efficient_backward is no longer used"
-        super().__init__(*args, **kwargs)
-        old_state, self.state = self.state, CustomMatmulLtState()
-        self.state.threshold = old_state.threshold
-        self.state.has_fp16_weights = old_state.has_fp16_weights
-        self.state.memory_efficient_backward = old_state.memory_efficient_backward
-        if old_state.threshold > 0.0 and not old_state.has_fp16_weights:
-            self.state.use_pool = True
-
-    def forward(self, x: torch.Tensor):
-        self.state.is_training = self.training
-        if self.weight.CB is not None:
-            self.init_8bit_state()
-
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        out = custom_matmul8bitlt(x, self.weight, bias=self.bias, state=self.state)
-        if not self.state.has_fp16_weights:
-            if self.state.CB is not None and self.state.CxB is not None:
-                # we converted 8-bit row major to turing/ampere format in the first inference pass
-                # we no longer need the row-major weight
-                del self.state.CB
-                self.weight.data = self.state.CxB
-        return out
-
-
-@dataclasses.dataclass(init=True)
-class CustomMatmulLtState(MatmulLtState):
-    tile_indices: Optional[torch.Tensor] = None
-    force_no_igemmlt: bool = False
-
-    def get_tile_size(self):
-        assert self.formatB in (
-            "col_turing",
-            "col_ampere",
-        ), f"please find this assert and manually enter tile size for {self.formatB}"
-        return (8, 32) if self.formatB == "col_turing" else (32, 32)
-
-
-def custom_matmul8bitlt(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    out: torch.Tensor = None,
-    state: CustomMatmulLtState = None,
-    threshold=0.0,
-    bias=None,
-):
-    state = state or MatmulLtState()
-    if threshold > 0.0:
-        state.threshold = threshold
-    return CustomMatMul8bitLt.apply(A, B, out, bias, state)
-
-
-class CustomMatMul8bitLt(MatMul8bitLt):
-    # forward is the same, but we added the fallback for pre-turing GPUs
-    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
-
-    @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=CustomMatmulLtState):
-        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
-        # default to pytorch behavior if inputs are empty
-        ctx.is_empty = False
-        if prod(A.shape) == 0:
-            ctx.is_empty = True
-            ctx.A = A
-            ctx.B = B
-            ctx.bias = bias
-            if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
-            else:
-                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
-
-        # 1. Quantize A
-        # 2. Quantize B
-        # 3. Matmul
-        # 4. Mixed-precision decomposition matmul
-        # 5. Save state
-        formatB = state.formatB
-        input_shape = A.shape
-        if state.outlier_pool is None:
-            state.outlier_pool = GlobalOutlierPooler.get_instance()
-
-        # Cast A to fp16
-        if A.dtype != torch.float16:
-            logging.debug(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
-
-        # 1. Quantize A
-        if len(A.shape) == 3:
-            A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
-
-        if state.threshold > 0.0 and coo_tensorA is not None:
-            if state.has_fp16_weights:
-                idx = torch.unique(coo_tensorA.colidx).long()
-                CA[:, idx] = 0
-                CAt[:, idx] = 0
-                subA = A[:, idx]
-                state.subB = B[:, idx].t().contiguous()
-                state.idx = idx
-            else:
-                if state.CxB is None and using_igemmlt:
-                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
-                    # we also need to convert it to the turing/ampere format
-                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
-        else:
-            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
-            subA = None
-
-        # 2. Quantize B
-        if state.has_fp16_weights:
-            has_grad = True if (getattr(B, "grad", None) is not None) else False
-            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
-            if is_transposed:
-                B = B.contiguous()
-
-            if (state.is_training and not has_grad) or state.CxB is None:
-                state.reset_grads()
-                (
-                    CB,
-                    state.CBt,
-                    state.SCB,
-                    state.SCBt,
-                    coo_tensorB,
-                ) = F.double_quant(B.to(torch.float16))
-                if using_igemmlt:
-                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
-                else:
-                    state.CB = CB
-        else:
-            has_grad = False
-
-        if coo_tensorA is not None and not state.has_fp16_weights:
-            # extract outliers
-
-            outlier_idx = torch.unique(coo_tensorA.colidx)
-            state.idx = outlier_idx
-            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
-            #     # do not use pool for 2nd FFN layer
-            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-            # else:
-            #     state.idx = outlier_idx
-            if state.CxB is not None:
-                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            else:
-                outliers = state.CB[:, state.idx.long()].clone()
-
-            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
-            CA[:, state.idx.long()] = 0
-            CAt[:, state.idx.long()] = 0
-            subA = A[:, state.idx.long()]
-
-        shapeB = state.SB[0] if state.SB else B.shape
-
-        if len(input_shape) == 3:
-            output_shape = (input_shape[0], input_shape[1], shapeB[0])
-        else:
-            output_shape = (input_shape[0], shapeB[0])
-
-        # 3. Matmul
-        if using_igemmlt:
-            C32A, SA = F.transform(CA, "col32")
-            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
-            if bias is None or bias.dtype == torch.float16:
-                # we apply the fused bias here
-                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-                output = output.to(A.dtype)
-            else:  # apply bias separately
-                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-                output = output.to(A.dtype).add_(bias)
-
-        else:
-            A_wo_outliers = A.clone()
-            if state.idx is not None:
-                A_wo_outliers[:, state.idx.long()] = 0
-            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
-            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
-            if bias is not None:
-                output = output.add_(bias)
-
-        # 4. Mixed-precision decomposition matmul
-        if coo_tensorA is not None and subA is not None:
-            output += torch.matmul(subA, state.subB)
-
-        # 5. Save state
-        ctx.state = state
-
-        ctx.formatB = formatB
-        ctx.grad_shape = input_shape
-        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
-
-        if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (CAt, subA)
-            ctx.tensor_states = (SCAt, state.idx)
-        else:
-            ctx.tensors = [None, None]
-            ctx.tensor_states = (None, None)
-            ctx.save_for_backward(None, None)
-
-        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
-        return clone_func(output.view(output_shape))
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if ctx.is_empty:
-            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
-        CAt, subA = ctx.tensors
-        SCAt, idx = ctx.tensor_states
-        formatB = ctx.formatB
-        state = ctx.state
-        grad_A = grad_B = grad_bias = None
-
-        if req_gradBias:
-            # compute grad_bias first before changing grad_output dtype
-            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
-
-        # Cast grad_output to fp16
-        if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
-
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
-        if req_gradB:
-            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
-            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
-            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
-            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
-            if state.threshold > 0.0 and subA is not None:
-                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
-
-        if req_gradA:
-            if state.CBt is not None:
-                C32grad, Sgrad = F.transform(Cgrad, "col32")
-                if state.CxBt is None:
-                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
-                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
-
-            elif state.CB is not None:
-                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
-                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
-            elif state.CxB is not None:
-
-                if state.tile_indices is None:
-                    order, tile_size = state.formatB, state.get_tile_size()
-                    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
-                    with torch.no_grad():
-                        state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
-
-                CB = (
-                    undo_layout(state.CxB, state.tile_indices)
-                    .to(ctx.dtype_A)
-                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
-                )
-                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
-            else:
-                raise Exception("State must contain either CBt or CB or CxB matrix for backward")
-
-        return grad_A, grad_B, None, grad_bias, None
@@ -1,19 +1,8 @@
-import importlib
 import os

 from hivemind.utils import logging as hm_logging


-def in_jupyter() -> bool:
-    """Check if the code is run in Jupyter or Colab"""
-
-    try:
-        __IPYTHON__
-        return True
-    except NameError:
-        return False
-
-
 def initialize_logs():
     """Initialize Petals logging tweaks. This function is called when you import the `petals` module."""

@@ -21,14 +10,6 @@ def initialize_logs():
     if os.getenv("PETALS_LOGGING", "True").lower() in ("false", "0"):
         return

-    if in_jupyter():
-        os.environ["HIVEMIND_COLORS"] = "True"
-    importlib.reload(hm_logging)
-
-    # Remove log handlers from previous import of hivemind.utils.logging and extra handlers on Colab
-    hm_logging.get_logger().handlers.clear()
-    hm_logging.get_logger("hivemind").handlers.clear()
-
     hm_logging.use_hivemind_log_handler("in_root_logger")

     # We suppress asyncio error logs by default since they are mostly not relevant for the end user,
288  src/petals/utils/peft.py  Normal file
@@ -0,0 +1,288 @@
+import contextlib
+import re
+import time
+from typing import Optional, Sequence, Union
+
+import bitsandbytes as bnb
+import torch
+import torch.nn as nn
+import transformers
+from accelerate import init_empty_weights
+from hivemind.utils.logging import get_logger
+from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
+from peft.tuners import lora
+from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers.utils import get_file_from_repo
+
+from petals.server.block_utils import resolve_block_dtype
+from petals.utils.convert_block import QuantType
+from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+
+logger = get_logger(__name__)
+
+
+def check_peft_repository(repo_id: str) -> bool:
+    fs = HfFileSystem()
+    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
+    return len(list_of_files) > 0
+
+
+def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
+    tensors = dict()
+    is_tensors_found = dict()
+    common_layer_patter_re = (
+        ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+"
+    )
+    with safe_open(filepath, framework=framework, device=device) as f:
+        for k in f.keys():
+            if re.match(common_layer_patter_re, k):
+                is_tensors_found[block_idx] = True
+                tensors[k] = f.get_tensor(k)
+        if not is_tensors_found.get(block_idx, False):
+            logger.warning(f"There are no peft weights for block {block_idx}")
+        return tensors
+
+
+def get_adapter_from_repo(
+    repo_id: str,
+    block_idx: Optional[int] = None,
+    device: Optional[int] = None,
+    *,
+    token: Optional[Union[str, bool]] = None,
+    **kwargs,
+):
+    config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
+    if config_path is None:
+        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
+    config = PeftConfig.from_json_file(config_path)
+
+    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
+    if weight_path is None:
+        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
+    if block_idx is None:
+        return config, load_file(weight_path)
+    return config, load_specific_module(block_idx, weight_path, device=device)
+
+
+def load_peft(
+    repo_id: str,
+    block_idx: Optional[int] = None,
+    device: Optional[int] = None,
+    *,
+    revision: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    delay: float = 30,
+):
+    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
+
+    if not check_peft_repository(repo_id):
+        raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")
+
+    try:
+        with allow_cache_reads(cache_dir):
+            return get_adapter_from_repo(
+                repo_id,
+                block_idx,
+                device,
+                revision=revision,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=False,
+            )
+    except Exception:
+        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
+
+    while True:
+        try:
+            with allow_cache_writes(cache_dir):
+                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
+                config_file_size = get_hf_file_metadata(config_url, token=token).size
+                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+                weight_file_size = get_hf_file_metadata(weight_url, token=token).size
+
+                file_size = config_file_size + weight_file_size
+                if file_size is not None:
+                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch size from peft repo {repo_id}")
+
+                return get_adapter_from_repo(
+                    repo_id,
+                    block_idx,
+                    device,
+                    revision=revision,
+                    token=token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+        except Exception as e:
+            logger.warning(
+                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
+            )
+            time.sleep(delay)
+
+
+class AdapterContextMixin:
+    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
+
+    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
+    _context_active_adapter = ADAPTER_NOT_SET
+
+    @staticmethod
+    @contextlib.contextmanager
+    def using_adapter(active_adapter: Optional[str]):
+        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
+        try:
+            yield
+        finally:
+            AdapterContextMixin._context_active_adapter = prev
+
+    @property
+    def active_adapter(self):
+        if self._context_active_adapter == self.ADAPTER_NOT_SET:
+            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
+        return self._context_active_adapter
+
+    @active_adapter.setter
+    def active_adapter(self, value: Optional[str]):
+        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
+
+
+using_adapter = AdapterContextMixin.using_adapter
+
+
+class LoraLinear(lora.Linear, AdapterContextMixin):
+    """LoRA linear layer that uses adapter selected via using_adapter"""
+
+
+class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
+    """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
+
+
+class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
+    """LoRA linear 4-bit that uses adapter selected via using_adapter"""
+
+
+def create_lora_adapter(block, quant_type: QuantType):
+    for _, module in block.named_modules():
+        for child_name, child in module.named_children():
+            lora_wrapped_child = None
+            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
+                continue
+            if quant_type == QuantType.INT8:
+                kwargs = {
+                    "has_fp16_weights": False,
+                    "threshold": 6.0,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = LoraLinear8bitLt(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            elif quant_type == QuantType.NF4:
+                kwargs = {
+                    "compress_statistics": True,
+                    "quant_type": "nf4",
+                    "blocksize": 64,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = LoraLinear4bit(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+                lora_wrapped_child.compute_dtype = child.compute_dtype
+            else:
+                bias = hasattr(child, "bias") and child.bias is not None
+                lora_wrapped_child = LoraLinear(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
+                    child.in_features,
+                    child.out_features,
+                    bias=bias,
+                )
+            if lora_wrapped_child:
+                lora_wrapped_child.weight = child.weight
+                lora_wrapped_child.bias = child.bias
+                for p in lora_wrapped_child.parameters():
+                    p.requires_grad = False
+                setattr(module, child_name, lora_wrapped_child)
+
+
+def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
+    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
+    if peft_config["lora_dropout"] > 0:
+        logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout")
+
+    for _, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if child_name in peft_config["target_modules"] or (
+                isinstance(peft_config["target_modules"], str)
+                and re.fullmatch(peft_config["target_modules"], child_name)
+            ):
+                is_lora_a_loaded = False
+                is_lora_b_loaded = False
+                for peft_key in peft_state_dict:
+                    if child_name not in peft_key:
+                        continue
+
+                    if adapter_name not in child.lora_A:
+                        child.update_layer(
+                            adapter_name,
+                            peft_config["r"],
+                            peft_config["lora_alpha"],
+                            lora_dropout=peft_config["lora_dropout"],
+                            init_lora_weights=peft_config["init_lora_weights"],
+                        )
+                        child.train(False)
+                        for p in child.parameters():
+                            p.requires_grad = False
+
+                    if peft_key.endswith(".lora_A.weight"):
+                        child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
+                        is_lora_a_loaded = True
+                    elif peft_key.endswith(".lora_A.bias"):
+                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
+                    elif peft_key.endswith(".lora_B.weight"):
+                        child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
+                        is_lora_b_loaded = True
+                    elif peft_key.endswith(".lora_B.bias"):
+                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
+
+                if is_lora_a_loaded and is_lora_b_loaded:
+                    logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}")
+                elif is_lora_a_loaded or is_lora_b_loaded:
+                    raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}")
+    logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
+
+
+def estimate_adapter_memory_per_block(
+    block_config: transformers.PretrainedConfig,
+    torch_dtype: Optional[torch.dtype],
+    adapters: Sequence[str],
+    **load_peft_kwargs,
+) -> int:
+    """Get the number of extra bytes used to store a set of adapters per given block"""
+    with init_empty_weights(include_buffers=True):
+        block = block_config.block_class(block_config)
+        base_block_parameters = sum(p.numel() for p in block.parameters())
+        create_lora_adapter(block, quant_type=QuantType.NONE)
+
+        for adapter in adapters:
+            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
+            assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
+            add_adapter_to_block(
+                block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
+            )
+
+        adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
+    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    return adapter_parameters * bytes_per_parameter
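A minimal sketch of the adapter-selection contract above: the active adapter is chosen through the using_adapter context manager rather than by mutating layers. The block and adapter names here are placeholders, and the block is assumed to have been wrapped by create_lora_adapter() and filled by add_adapter_to_block():

from petals.utils.peft import using_adapter

def run_with_adapter(block, hidden_states, adapter_name: str):
    # Inside the context, LoraLinear* layers pick the weights registered under `adapter_name`.
    with using_adapter(adapter_name):
        return block(hidden_states)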
64
src/petals/utils/ping.py
Normal file
64
src/petals/utils/ping.py
Normal file
@ -0,0 +1,64 @@
import asyncio
import math
import threading
import time
from functools import partial
from typing import Dict, Sequence

import hivemind
from hivemind.proto import dht_pb2
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


async def ping(
    peer_id: hivemind.PeerID,
    _dht: hivemind.DHT,
    node: hivemind.dht.DHTNode,
    *,
    wait_timeout: float = 5,
) -> float:
    try:
        ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)
        start_time = time.perf_counter()
        await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
        return time.perf_counter() - start_time
    except Exception as e:
        if str(e) == "protocol not supported":  # Happens on servers with client-mode DHT (e.g., reachable via relays)
            return time.perf_counter() - start_time

        logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
        return math.inf


async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]:
    rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids])
    return dict(zip(peer_ids, rpc_infos))


class PingAggregator:
    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300):
        self.dht = dht
        self.ema_alpha = ema_alpha
        self.expiration = expiration
        self.ping_emas = hivemind.TimedStorage()
        self.lock = threading.Lock()

    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:
        current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
        logger.debug(f"Current RTTs: {current_rtts}")

        with self.lock:
            expiration = hivemind.get_dht_time() + self.expiration
            for peer_id, rtt in current_rtts.items():
                prev_rtt = self.ping_emas.get(peer_id)
                if prev_rtt is not None and prev_rtt.value != math.inf:
                    rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing
                self.ping_emas.store(peer_id, rtt, expiration)

    def to_dict(self) -> Dict[hivemind.PeerID, float]:
        with self.lock, self.ping_emas.freeze():
            smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
            logger.debug(f"Smoothed RTTs: {smoothed_rtts}")
            return smoothed_rtts
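A minimal usage sketch for the aggregator above (assumed usage, not part of this commit; INITIAL_PEERS and peer_ids are placeholders you would obtain from your own swarm):

import hivemind
from petals.utils.ping import PingAggregator

# INITIAL_PEERS: your bootstrap multiaddrs; peer_ids: a Sequence[hivemind.PeerID] of servers to probe
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
aggregator = PingAggregator(dht, ema_alpha=0.2, expiration=300)
aggregator.ping(peer_ids)    # each call updates the per-peer EMA: rtt <- 0.2 * rtt_now + 0.8 * rtt_prev
print(aggregator.to_dict())  # {peer_id: smoothed RTT in seconds; math.inf if the peer never answered}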
12  src/petals/utils/random.py  Normal file
@@ -0,0 +1,12 @@
import random
from typing import Collection, TypeVar

T = TypeVar("T")


def sample_up_to(population: Collection[T], k: int) -> T:
    if not isinstance(population, list):
        population = list(population)
    if len(population) > k:
        population = random.sample(population, k)
    return population
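For illustration, the helper differs from random.sample in that it never raises when the population is smaller than k (example values below are arbitrary):

from petals.utils.random import sample_up_to

servers = {"peerA", "peerB", "peerC"}
print(sample_up_to(servers, k=2))   # two randomly chosen items
print(sample_up_to(servers, k=10))  # all three items, no error despite k > len(population)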
44  src/petals/utils/version.py  Normal file
@@ -0,0 +1,44 @@
import os
import re
from typing import Union

import requests
from hivemind.utils.logging import TextStyle, get_logger
from packaging.version import parse

import petals

logger = get_logger(__name__)


def validate_version() -> None:
    logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}")
    try:
        r = requests.get("https://pypi.python.org/pypi/petals/json")
        r.raise_for_status()
        response = r.json()

        versions = [parse(ver) for ver in response.get("releases")]
        latest = max(ver for ver in versions if not ver.is_prerelease)

        if parse(petals.__version__) < latest:
            logger.info(
                f"A newer version {latest} is available. Please upgrade with: "
                f"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}"
            )
    except Exception as e:
        logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True)


def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]:
    if model_name_or_path is None:
        return None

    match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path))
    if match is None:
        return model_name_or_path

    logger.info(
        f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones"
    )
    return match.group(1)
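A short sketch of how these helpers behave (illustrative calls only; the repo names are examples):

from petals.utils.version import get_compatible_model_repo, validate_version

validate_version()  # logs the running version and hints at an upgrade if PyPI has a newer release
print(get_compatible_model_repo("bigscience/bloom-petals"))  # -> "bigscience/bloom" (legacy converted repo remapped)
print(get_compatible_model_repo("bigscience/bloom-560m"))    # unchanged: no "-petals" suffix to strip
print(get_compatible_model_repo(None))                       # -> None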
@@ -1,25 +0,0 @@
import argparse
from datetime import datetime

from huggingface_hub import delete_repo, list_models

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
    parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained")
    parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument("--dry_run", action="store_true")

    args = parser.parse_args()

    for model in list_models(author=args.author, full=True):
        last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")

        if model.modelId.endswith("-main") or "/test-" not in model.modelId:
            continue  # remove only test models

        if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
            if args.dry_run:
                print(f"{model.modelId} can be deleted")
            else:
                delete_repo(repo_id=model.modelId, token=args.use_auth_token)
BIN  tests/server2.id  Normal file
Binary file not shown.
@@ -1,17 +1,46 @@
+import subprocess
+import sys
+
 import pytest
 import torch
+
+from petals import AutoDistributedConfig
+from petals.server.throughput import measure_compute_rps
+from petals.utils.convert_block import QuantType
 from test_utils import MODEL_NAME
 
-from petals.client import DistributedBloomConfig
-from petals.server.throughput import measure_compute_rps, measure_network_rps
+
+def test_bnb_not_imported_when_unnecessary():
+    """
+    We avoid importing bitsandbytes when it's not used,
+    since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that.
+
+    If this test fails, please change your code to import bitsandbytes and/or petals.utils.peft
+    in the function's/method's code when it's actually needed instead of importing them in the beginning of the file.
+    This won't slow down the code - importing a module for the 2nd time doesn't rerun module code.
+    """
+
+    subprocess.check_call([sys.executable, "-c", "import petals, sys; assert 'bitsandbytes' not in sys.modules"])
 
 
 @pytest.mark.forked
-def test_throughput_basic():
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+@pytest.mark.parametrize("inference", [False, True])
+@pytest.mark.parametrize("n_tokens", [1, 16])
+@pytest.mark.parametrize("tensor_parallel", [False, True])
+def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+    if tensor_parallel and config.model_type != "bloom":
+        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
+
+    tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
-        config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
+        config,
+        device=torch.device("cpu"),
+        dtype=torch.bfloat16,
+        quant_type=QuantType.NONE,
+        tensor_parallel_devices=tensor_parallel_devices,
+        n_tokens=n_tokens,
+        n_steps=5,
+        inference=inference,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
-    network_rps = measure_network_rps(config)
-    assert isinstance(network_rps, float) and network_rps > 0
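The docstring in the new test spells out the lazy-import policy it enforces; a hedged sketch of the pattern it asks for (module and function names below are illustrative, not from this commit):

# Keep heavy, CUDA-sensitive imports inside the function that needs them,
# so that `import petals` alone never pulls in bitsandbytes.
def quantize_weights(weights):
    import bitsandbytes as bnb  # imported on first call; cached in sys.modules afterwards

    return bnb.nn.Int8Params(weights, requires_grad=False)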
@@ -1,43 +1,43 @@
 import random
 
-import hivemind
 import pytest
 import torch
-from test_utils import *
 
-from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import DistributedBloomConfig
-from petals.client.remote_sequential import RemoteTransformerBlock
-from petals.data_structures import UID_DELIMITER
-from petals.dht_utils import get_remote_module
+from petals import AutoDistributedConfig, RemoteSequential
+from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
+from petals.server.from_pretrained import load_pretrained_block
+from test_utils import *
 
 
 @pytest.mark.forked
-def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
 
-    for block_index in random.sample(range(config.n_layer), 3):
-        remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
-        assert isinstance(remote_block, RemoteTransformerBlock)
+    block_index = random.randint(0, config.num_hidden_layers - 1)
+    remote_block = remote_sequential[block_index]
 
-        inputs = torch.randn(1, 8, config.hidden_size)
-        outputs_forward = remote_block(inputs)
+    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
+    outputs_forward = remote_block(inputs)
 
-        outputs_inference = []
-        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
-            for i in range(inputs.shape[1]):
-                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+    outputs_inference = []
+    with torch.inference_mode():
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+            # Test long inference (unmerged inference pools)
+            outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
+
+            # Test short inference (merged inference pools)
+            for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
+                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
 
             # test that max length is respected
             with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
                 sess.step(inputs[:, -1:, :])
             assert "Maximum length exceeded" in repr(exc_info.value)
+    outputs_inference = torch.cat(outputs_inference, dim=1)
 
-        outputs_inference = torch.cat(outputs_inference, dim=1)
-
-        ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
-        (outputs_local,) = ref_block(inputs)
+    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+    (outputs_local,) = ref_block(inputs)
 
-        assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
-        assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+    assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
+    assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
@@ -4,22 +4,19 @@
 # - if you want to figure out chained inference, ask yozh
 
-import hivemind
 import pytest
 import torch
-from test_utils import *
 
-from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.client.remote_sequential import RemoteSequential
-from petals.dht_utils import get_remote_sequence
+from petals.server.from_pretrained import load_pretrained_block
+from test_utils import *
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_blocks = get_remote_sequence(dht, 3, 6, config)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
     assert isinstance(remote_blocks, RemoteSequential)
 
     ref_blocks = [
@@ -46,10 +43,8 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_blocks = get_remote_sequence(dht, 3, 5, config)
-    assert isinstance(remote_blocks, RemoteSequential)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_blocks = RemoteSequential(config, start_block=3, end_block=5)
 
     inputs = torch.randn(1, 8, config.hidden_size)
 
16  tests/test_dtype.py  Normal file
@@ -0,0 +1,16 @@
import pytest
import torch

from petals.server.block_utils import resolve_block_dtype
from petals.server.from_pretrained import load_pretrained_block
from petals.utils.auto_config import AutoDistributedConfig
from test_utils import MODEL_NAME


@pytest.mark.forked
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
def test_block_dtype(torch_dtype):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)
    expected_dtype = resolve_block_dtype(config, torch_dtype)
    assert all(param.dtype == expected_dtype for param in block.parameters())
@@ -1,28 +1,36 @@
+import peft
 import pytest
 import torch
 import transformers
 from hivemind import get_logger
+from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin
+
+from petals import AutoDistributedModelForCausalLM
 from test_utils import *
-from transformers.generation import BeamSearchScorer
-from transformers.models.bloom import BloomForCausalLM
 
-from petals.client.remote_model import DistributedBloomForCausalLM
+logger = get_logger(__name__)
 
-logger = get_logger(__file__)
+
+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME,
+        initial_peers=INITIAL_PEERS,
+        torch_dtype=torch.float32,
+        active_adapter=ADAPTER_NAME if use_peft else None,
     )
     config = model.config
-    assert isinstance(model, DistributedBloomForCausalLM)
-    assert len(model.transformer.h) == model.config.n_layer
+    assert len(model.transformer.h) == model.config.num_hidden_layers
 
-    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
 
     with torch.inference_mode():
         parallel_outputs = model.forward(test_inputs).logits
@@ -37,8 +45,14 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
 
             for t in range(embs.shape[1]):
-                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
-                if t == int(embs.shape[1] // 2) and pass_empty_tensors:
+                if t == 4:
+                    recurrent_outputs.append(sess.step(embs[:, 4:9, :]))
+                elif 4 < t < 9:
+                    continue
+                else:
+                    recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+
+                if t == 2 and pass_empty_tensors:
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
 
@@ -51,9 +65,12 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
     del model, embs, recurrent_outputs
 
     if REF_NAME:
-        ref_model = transformers.BloomForCausalLM.from_pretrained(
+        ref_model = transformers.AutoModelForCausalLM.from_pretrained(
             REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
         )
+        if use_peft:
+            ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
+            ref_model.train(False)
         if config.vocab_size < ref_model.config.vocab_size:
            ref_model.resize_token_embeddings(config.vocab_size)
            logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
@@ -71,27 +88,29 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
 
 
 @pytest.mark.forked
-def test_greedy_generation(max_new_tokens=4):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_greedy_generation(tokenizer, max_new_tokens=4):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     remote_outputs = model.generate(
         inputs,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
+    hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
     assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
 
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
     inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
         "input_ids"
     ]
 
     remote_outputs_batch = model.generate(
         inputs_batch,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs_batch = BloomForCausalLM.greedy_search(
+    hf_outputs_batch = HfGenerationMixin.greedy_search(
         model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
     )
     assert torch.allclose(
@@ -102,13 +121,13 @@ def test_greedy_generation(max_new_tokens=4):
 @pytest.mark.forked
 @pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
 @pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
-def test_sampling(sampling_options, max_new_tokens=4):
+def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
     torch.manual_seed(0)
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
-    logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
+    logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     with torch.random.fork_rng():
         remote_outputs = model.generate(
@@ -118,7 +137,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
             **sampling_options,
         )
     with torch.random.fork_rng():
-        hf_outputs = BloomForCausalLM.sample(
+        hf_outputs = HfGenerationMixin.sample(
             model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
         )
     assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
@@ -134,7 +153,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
             **sampling_options,
         )
     with torch.random.fork_rng():
-        hf_outputs_batch = BloomForCausalLM.sample(
+        hf_outputs_batch = HfGenerationMixin.sample(
             model,
             input_ids=inputs_batch,
             max_length=inputs_batch.size(1) + max_new_tokens,
@@ -146,10 +165,9 @@ def test_sampling(sampling_options, max_new_tokens=4):
 
 
 @pytest.mark.forked
-def test_beam_search_generation(max_new_tokens=4, num_beams=2):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
     text = "A cat sat on a mat"
     inputs = tokenizer(text, return_tensors="pt")["input_ids"]
@@ -166,7 +184,7 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
         do_early_stopping=False,
     )
     hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
-    hf_outputs = BloomForCausalLM.beam_search(
+    hf_outputs = HfGenerationMixin.beam_search(
         model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
     )
     assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
@@ -1,108 +0,0 @@
import bitsandbytes as bnb
import pytest
import torch
from bitsandbytes import functional as F

from petals.utils.linear8bitlt_patch import CustomLinear8bitLt, get_inverse_transform_indices, undo_layout


@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
    x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
    for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
        transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
        tile_indices = get_inverse_transform_indices(transform, tile_size)
        cxb = transform(x)

        torch.cuda.synchronize()
        restored_x = undo_layout(cxb, tile_indices)
        torch.cuda.synchronize()
        assert restored_x.is_contiguous()
        assert torch.all(torch.eq(restored_x, x))


@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_linear_exact_match():
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
    linear8bitlt = bnb.nn.Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
        memory_efficient_backward=True,
    )
    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
        linear.weight.dtype
    )
    linear8bitlt.bias = linear.bias
    linear8bitlt.cuda()

    linear_custom = CustomLinear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom.cuda()

    x_ref = x.clone().cuda().requires_grad_(True)
    x_ours = x.clone().cuda().requires_grad_(True)
    fx_ref = linear8bitlt(x_ref).float()
    grad_proj = torch.randn_like(fx_ref)
    (fx_ref * grad_proj).mean().backward()

    fx_ours = linear_custom(x_ours).float()
    (fx_ours * grad_proj).mean().backward()
    assert torch.equal(fx_ref, fx_ours)
    assert torch.allclose(x_ref.grad, x_ours.grad)
    assert not linear_custom.state.has_fp16_weights
    assert linear_custom.state.CB is None
    assert linear_custom.state.CxB is not None


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt():
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
    linear_custom = CustomLinear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    linear_custom.state.force_no_igemmlt = True

    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom.cuda()
    linear.half().cuda()

    x_ref = x.clone().cuda().requires_grad_(True)
    x_ours = x.clone().cuda().requires_grad_(True)
    fx_ref = linear(x_ref).float()
    grad_proj = torch.randn_like(fx_ref)
    (fx_ref * grad_proj).mean().backward()

    fx_ours = linear_custom(x_ours).float()
    (fx_ours * grad_proj).mean().backward()
    assert torch.allclose(fx_ref, fx_ours, atol=0.02)
    assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
    assert not linear_custom.state.has_fp16_weights
    assert linear_custom.state.CB is not None
    assert linear_custom.state.CxB is None
66  tests/test_peft.py  Normal file
@@ -0,0 +1,66 @@
import os
import shutil

import pytest
from huggingface_hub import snapshot_download

from petals.utils.peft import check_peft_repository, load_peft

UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
TMP_CACHE_DIR = "tmp_cache/"


def clear_dir(path_to_dir):
    shutil.rmtree(path_to_dir)
    os.mkdir(path_to_dir)


def dir_empty(path_to_dir):
    files = os.listdir(path_to_dir)
    return len(files) == 0


@pytest.mark.forked
def test_check_peft():
    assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
    assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."


@pytest.mark.forked
def test_load_noncached(tmpdir):
    clear_dir(tmpdir)
    with pytest.raises(Exception):
        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)

    assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"

    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)

    assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"


@pytest.mark.forked
def test_load_cached(tmpdir):
    clear_dir(tmpdir)
    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)

    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)


@pytest.mark.forked
def test_load_layer_exists(tmpdir):
    clear_dir(tmpdir)

    load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)


@pytest.mark.forked
def test_load_layer_nonexists(tmpdir):
    clear_dir(tmpdir)

    load_peft(
        SAFE_PEFT_REPO,
        block_idx=1337,
        cache_dir=tmpdir,
    )
@@ -1,25 +1,26 @@
 import pytest
 import torch
+import torch.nn.functional as F
 from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2
+
+from petals import AutoDistributedConfig
+from petals.client import RemoteSequenceManager, RemoteSequential
+from petals.data_structures import UID_DELIMITER
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
-from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import RemoteSequenceManager, RemoteSequential
-from petals.client.remote_model import DistributedBloomConfig
-from petals.data_structures import UID_DELIMITER
-
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 
 @pytest.mark.forked
 def test_remote_sequential():
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
     grad_proj = torch.randn(1, 5, config.hidden_size)
 
-    sequential = RemoteSequential(config, dht)
+    sequential = RemoteSequential(config, dht=dht)
 
     full_outputs = sequential(test_inputs)
     (full_outputs * grad_proj).sum().backward()
@@ -27,10 +28,10 @@ def test_remote_sequential():
     full_grad = test_inputs.grad.clone()
     test_inputs.grad.data.zero_()
 
-    first_half = sequential[: config.n_layer // 2]
-    second_half = sequential[config.n_layer // 2 :]
+    first_half = sequential[: config.num_hidden_layers // 2]
+    second_half = sequential[config.num_hidden_layers // 2 :]
     assert len(first_half) + len(second_half) == len(sequential)
-    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
+    assert abs(len(first_half) - len(second_half)) == config.num_hidden_layers % 2
     for m in sequential, first_half, second_half:
         assert isinstance(repr(m), str)
 
@@ -39,15 +40,15 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs)
+    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
 
     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad)
+    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)
 
     # test RemoteSequential with lossy compression
-    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
     lossy_sequential = RemoteSequential(
-        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p, start=True)
+        config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht)
     )
 
     test_inputs.grad = None
@@ -55,10 +56,10 @@ def test_remote_sequential():
     (approx_outputs * grad_proj).sum().backward()
 
     assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
-    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
+    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-3), "compression was not used"
     assert abs(approx_outputs - full_outputs).mean() < 0.01
     absmax = abs(full_grad).max()
-    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01
+    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
 
 
 class DummyCustomSequenceManager(RemoteSequenceManager):
@@ -77,20 +78,24 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
         if protocol == "rpc_forward":
             metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
         elif protocol == "rpc_backward":
-            metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,)
+            metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
+            # FIXME: Initially, we used CompressionType.BLOCKWISE_8BIT for rpc_backward() here.
+            # This is currently broken since hivemind==1.1.8 is not compatible with bitsandbytes==0.39.1.
+            # Please revert to BLOCKWISE_8BIT once this is fixed: https://github.com/learning-at-home/hivemind/issues/572
         return metadata
 
 
 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
-    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
-    remote_sequential = RemoteSequential(config, dht)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
 
-    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
-    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
-    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
-    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
+    output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
+    input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
+    intermediate_prompts = torch.randn(
+        config.num_hidden_layers, batch_size, pre_seq_len, config.hidden_size, requires_grad=True
+    )
 
     input_prompts = input_prompts.detach().requires_grad_(True)
     intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
@@ -110,17 +115,17 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     assert intermediate_prompts_ref.grad is None
 
     outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
-    for block_index in range(config.n_layer):
+    for block_index in range(config.num_hidden_layers):
         block_prompt = intermediate_prompts_ref[block_index]
         outputs_ref[:, : block_prompt.shape[1]] += block_prompt
 
         block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
         (outputs_ref,) = block(outputs_ref)
 
-    assert torch.allclose(outputs_ref, outputs)
+    assert torch.allclose(outputs_ref, outputs, atol=1e-3)
 
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
     assert intermediate_prompts_ref.grad is not None
-    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)
@@ -4,31 +4,31 @@ import time
 import pytest
 import torch
 from hivemind import DHT, get_logger
+
+from petals import AutoDistributedConfig
+from petals.client import RemoteSequenceManager, RemoteSequential
+from petals.data_structures import UID_DELIMITER
 from test_utils import *
 
-from petals.client import RemoteSequenceManager, RemoteSequential
-from petals.client.remote_model import DistributedBloomConfig
-from petals.data_structures import UID_DELIMITER
-
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 
 @pytest.mark.forked
-def test_sequence_manager_shutdown():
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+@pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
+def test_sequence_manager_basics(mode: str):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
-    sequential = RemoteSequential(config, dht)
+    sequential = RemoteSequential(config, dht=dht)
     shutdown_evt = threading.Event()
 
     # test RemoteSequential with lossy compression
-    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
     sequential = RemoteSequential(
         config,
-        dht,
-        sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True),
+        sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
    )
 
-    sequence = sequential.sequence_manager.make_sequence()
+    sequence = sequential.sequence_manager.make_sequence(mode=mode)
     assert all(sequence[i].peer_id != sequence[i + 1].peer_id for i in range(len(sequence) - 1))
 
     assert sequential.sequence_manager.is_alive()
@@ -43,7 +43,7 @@ def test_sequence_manager_shutdown():
     assert shutdown_evt.is_set()
 
 
-class TestSequenceManager(RemoteSequenceManager):
+class RemoteSequenceManagerWithChecks(RemoteSequenceManager):
     """A sequence manager that signals if it was shut down"""
 
     def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
39  tests/test_server_stats.py  Normal file
@@ -0,0 +1,39 @@
import time

import hivemind
import pytest
import torch

from petals import AutoDistributedConfig, RemoteSequential
from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *


@pytest.mark.forked
def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"]  # PeerID from server2.id

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
    blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)

    info_before = blocks1.sequence_manager.rpc_info

    with blocks1.inference_session(max_length=max_length) as sess:
        sess.step(torch.randn(1, 1, config.hidden_size))
        blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
        info_inside = blocks1.sequence_manager.rpc_info

        with blocks2.inference_session(max_length=max_length2) as sess2:
            sess2.step(torch.randn(1, 1, config.hidden_size))
            blocks2.sequence_manager.state.rpc_info = None  # invalidate cache
            info_inside2 = blocks2.sequence_manager.rpc_info

    time.sleep(0.1)
    blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
    info_after = blocks1.sequence_manager.rpc_info

    assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]
    assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1)
    assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2)
49  tests/test_tensor_parallel.py  Normal file
@@ -0,0 +1,49 @@
import random

import pytest
import torch
import transformers
from tensor_parallel import TensorParallel
from tensor_parallel.slicing_configs import get_bloom_config

from petals.server.from_pretrained import load_pretrained_block
from test_utils import MODEL_NAME


@pytest.mark.forked
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
    if model_config.model_type != "bloom":
        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")

    block_index = random.randint(0, 10)
    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])

    tp_config = None
    if custom_config:
        tp_config = get_bloom_config(model_config, devices)

    batch_size = 2
    prefix_length = 5

    test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
    test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
    grad_proj = torch.rand_like(test_inputs1)

    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
    y_ref.backward(grad_proj)

    block_tp = TensorParallel(block, devices, config=tp_config)
    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
    y_ours.backward(grad_proj)

    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
    assert torch.allclose(y_ours, y_ref, atol=1e-5)
    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)
@@ -11,3 +11,5 @@ if not MODEL_NAME:
     raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
 
 REF_NAME = os.environ.get("REF_NAME")
+
+ADAPTER_NAME = os.environ.get("ADAPTER_NAME")