Make Petals a pip-installable package (attempt 2) (#102)

1. Petals can now be installed using `pip install git+https://github.com/bigscience-workshop/petals`
    - If you have already cloned the repo, you can run `pip install .` or `pip install .[dev]` instead
2. Moved `src` => `src/petals`
    - Replaced `from src.smth import smth` with `from petals.smth import smth` (see the usage sketch after this list)
3. Moved `cli` => `src/petals/cli`
    - Replaced `python -m cli.run_smth` with `python -m petals.cli.run_smth` (all utilities are now available right after pip installation)
4. Moved the `requirements*.txt` contents to `setup.cfg` (`requirements.txt` for packages is not supported well by modern packaging utils)
5. Increased the package version from `0.2` to `1.0alpha1`
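
For illustration, here is a minimal sketch of the new layout in use. This is an assumption-laden example, not part of the commit itself: it reuses the placeholder model name, peer address, and `from_pretrained` arguments from the README and CI config further down, and assumes a local test swarm is already running (started via `python -m petals.cli.run_server ...`, formerly `python -m cli.run_server ...`):

```python
import torch

# New package-style import (previously: `from src import DistributedBloomForCausalLM`)
from petals.client import DistributedBloomForCausalLM

# Example peer address of a local test server, taken from the CI config below
initial_peers = ["/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g"]

model = DistributedBloomForCausalLM.from_pretrained(
    "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
)  # only embeddings / logits are local; all transformer blocks run on remote servers
```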
Alexander Borzunov committed 1 year ago via GitHub
parent 0c3781a89c
commit 7bd5916744

@@ -21,11 +21,11 @@ jobs:
uses: actions/cache@v2
with:
path: ~/.cache/pip
-key: Key-v1-py3.9-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+key: Key-v1-py3.9-${{ hashFiles('setup.cfg') }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
-pip install -r requirements.txt
+pip install .
- name: Delete any test models older than 1 week
run: |
python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
@@ -37,7 +37,7 @@ jobs:
- name: Convert model and push to hub
run: |
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
-python -m cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \
+python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \
--output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \
--resize_token_embeddings 50000
@@ -59,19 +59,18 @@ jobs:
uses: actions/cache@v2
with:
path: ~/.cache/pip
-key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
-pip install -r requirements.txt
-pip install -r requirements-dev.txt
+pip install .[dev]
- name: Test
run: |
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG
export REF_NAME=bigscience/bloom-560m
-python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
+python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
--torch_dtype float32 --compression NONE --attn_cache_size 0.2GiB &> server1.log &
SERVER1_PID=$!
@@ -81,21 +80,21 @@ jobs:
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
-python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
+python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log &
SERVER2_PID=$!
sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
-python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \
+python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
SERVER3_PID=$!
-python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
+python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
--torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
SERVER4_PID=$!
-python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
+python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server5.log &
SERVER5_PID=$!

@@ -85,10 +85,10 @@ This is important because it's technically possible for peers serving model laye
## Installation
-Here's how to install the dependencies with conda:
+Here's how to install Petals with conda:
```
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
-pip install -r requirements.txt
+pip install git+https://github.com/bigscience-workshop/petals
```
This script uses Anaconda to install cuda-enabled PyTorch.
@@ -107,7 +107,7 @@ For a detailed instruction with larger models, see ["Launch your own swarm"](htt
First, run a couple of servers, each in a separate shell. To launch your first server, run:
```bash
-python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
+python -m petals.cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
--host_maddrs /ip4/127.0.0.1/tcp/31337 # use port 31337, local connections only
```
@@ -124,7 +124,7 @@ Mon Day 01:23:45.678 [INFO] Running DHT node on ['/ip4/127.0.0.1/tcp/31337/p2p/A
You can use this address (`/ip4/whatever/else`) to connect additional servers. Open another terminal and run:
```bash
-python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
+python -m petals.cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
--host_maddrs /ip4/127.0.0.1/tcp/0 \
--initial_peers /ip4/127.0... # <-- TODO: Copy the address of another server here
# e.g. --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS1GecIfYouAreReadingThisYouNeedToCopyYourServerAddressCBBq
@@ -140,11 +140,11 @@ Once your have enough servers, you can use them to train and/or inference the mo
```python
import torch
import torch.nn.functional as F
-import transformers
-from src import DistributedBloomForCausalLM
+from transformers import BloomTokenizerFast
+from petals.client import DistributedBloomForCausalLM
initial_peers = [TODO_put_one_or_more_server_addresses_here] # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
-tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
+tokenizer = BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
model = DistributedBloomForCausalLM.from_pretrained(
"bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
) # this model has only embeddings / logits, all transformer blocks rely on remote servers
@@ -170,21 +170,26 @@ Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/
## 🛠️ Development
-Petals uses pytest with a few plugins. To install them, run `pip install -r requirements-dev.txt`
+Petals uses pytest with a few plugins. To install them, run:
+```python
+git clone https://github.com/bigscience-workshop/petals.git && cd petals
+pip install -e .[dev]
+```
To run minimalistic tests, spin up some servers:
```bash
export MODEL_NAME=bloom-testing/test-bloomd-560m-main
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
-python -m cli.run_server $MODEL_NAME --block_indices 0:12 --throughput 1 --torch_dtype float32 \
+python -m petals.cli.run_server $MODEL_NAME --block_indices 0:12 --throughput 1 --torch_dtype float32 \
--identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> server1.log &
sleep 5 # wait for the first server to initialize DHT
-python -m cli.run_server $MODEL_NAME --block_indices 12:24 --throughput 1 --torch_dtype float32 \
+python -m petals.cli.run_server $MODEL_NAME --block_indices 12:24 --throughput 1 --torch_dtype float32 \
--initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g &> server2.log &
tail -f server1.log server2.log # view logs for both servers
-# after you're done, kill servers with 'pkill -f cli.run_server'
+# after you're done, kill servers with 'pkill -f petals.cli.run_server'
```
Then launch pytest:

@@ -36,22 +36,15 @@
"import subprocess\n",
"import sys\n",
"\n",
+"!pip install -r git+https://github.com/bigscience-workshop/petals\n",
+"!pip install datasets wandb\n",
"\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
-"\n",
-"if IN_COLAB:\n",
-"    subprocess.run(\"git clone https://github.com/bigscience-workshop/petals\", shell=True)\n",
-"    subprocess.run(\"pip install -r petals/requirements.txt\", shell=True)\n",
-"    subprocess.run(\"pip install datasets wandb\", shell=True)\n",
-"\n",
+"if IN_COLAB:  # Remove CUDA binaries on CPU-only colabs to not confuse bitsandbytes\n",
"    try:\n",
"        subprocess.check_output([\"nvidia-smi\", \"-L\"])\n",
"    except subprocess.CalledProcessError as e:\n",
-"        subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)\n",
-"\n",
-"    sys.path.insert(0, './petals/')\n",
-"else:\n",
-"    sys.path.insert(0, \"..\")"
+"        subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)"
]
},
{
@@ -62,7 +55,6 @@
"outputs": [],
"source": [
"import os\n",
-"import sys\n",
" \n",
"import torch\n",
"import transformers\n",
@@ -71,10 +63,10 @@
"from tqdm import tqdm\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader\n",
-"from transformers import get_scheduler\n",
+"from transformers import BloomTokenizerFast, get_scheduler\n",
"\n",
"# Import a Petals model\n",
-"from src.client.remote_model import DistributedBloomForCausalLM"
+"from petals.client.remote_model import DistributedBloomForCausalLM"
]
},
{
@@ -120,7 +112,7 @@
"metadata": {},
"outputs": [],
"source": [
-"tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+"tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
"tokenizer.padding_side = 'right'\n",
"tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
"model = DistributedBloomForCausalLM.from_pretrained(\n",
@@ -314,7 +306,7 @@
],
"metadata": {
"kernelspec": {
-"display_name": "Python 3.8.0 ('petals')",
+"display_name": "Python 3.6.9 64-bit",
"language": "python",
"name": "python3"
},
@@ -328,11 +320,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.8.0"
+"version": "3.6.9"
},
"vscode": {
"interpreter": {
-"hash": "a303c9f329a09f921588ea6ef03898c90b4a8e255a47e0bd6e36f6331488f609"
+"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},

@@ -36,22 +36,15 @@
"import subprocess\n",
"import sys\n",
"\n",
+"!pip install -r git+https://github.com/bigscience-workshop/petals\n",
+"!pip install datasets wandb\n",
"\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
-"\n",
-"if IN_COLAB:\n",
-"    subprocess.run(\"git clone https://github.com/bigscience-workshop/petals\", shell=True)\n",
-"    subprocess.run(\"pip install -r petals/requirements.txt\", shell=True)\n",
-"    subprocess.run(\"pip install datasets wandb\", shell=True)\n",
-"\n",
+"if IN_COLAB:  # Remove CUDA binaries on CPU-only colabs to not confuse bitsandbytes\n",
"    try:\n",
"        subprocess.check_output([\"nvidia-smi\", \"-L\"])\n",
"    except subprocess.CalledProcessError as e:\n",
-"        subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)\n",
-"\n",
-"    sys.path.insert(0, './petals/')\n",
-"else:\n",
-"    sys.path.insert(0, \"..\")"
+"        subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)"
]
},
{
@@ -62,7 +55,6 @@
"outputs": [],
"source": [
"import os\n",
-"import sys\n",
" \n",
"import torch\n",
"import transformers\n",
@@ -71,10 +63,10 @@
"from tqdm import tqdm\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader\n",
-"from transformers import get_scheduler\n",
+"from transformers import BloomTokenizerFast, get_scheduler\n",
"\n",
"# Import a Petals model\n",
-"from src.client.remote_model import DistributedBloomForSequenceClassification"
+"from petals.client.remote_model import DistributedBloomForSequenceClassification"
]
},
{
@@ -121,7 +113,7 @@
"metadata": {},
"outputs": [],
"source": [
-"tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+"tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
"tokenizer.padding_side = 'right'\n",
"tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
"model = DistributedBloomForSequenceClassification.from_pretrained(\n",
@@ -313,7 +305,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.8.9"
+"version": "3.6.9"
},
"vscode": {
"interpreter": {

@@ -1,3 +1,10 @@
+[build-system]
+requires = [
+    "setuptools>=42",
+    "wheel"
+]
+build-backend = "setuptools.build_meta"
+
[tool.black]
line-length = 120
required-version = "22.3.0"

@@ -0,0 +1,54 @@
[metadata]
name = petals
version = 1.0alpha1
author = Petals Developers
author_email = petals-dev@googlegroups.com
description = Easy way to efficiently run 100B+ language models without high-end GPUs
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/bigscience-workshop/petals
project_urls =
    Bug Tracker = https://github.com/bigscience-workshop/petals/issues
classifiers =
    Development Status :: 4 - Beta
    Intended Audience :: Developers
    Intended Audience :: Science/Research
    License :: OSI Approved :: MIT License
    Programming Language :: Python :: 3
    Programming Language :: Python :: 3.7
    Programming Language :: Python :: 3.8
    Programming Language :: Python :: 3.9
    Topic :: Scientific/Engineering
    Topic :: Scientific/Engineering :: Mathematics
    Topic :: Scientific/Engineering :: Artificial Intelligence
    Topic :: Software Development
    Topic :: Software Development :: Libraries
    Topic :: Software Development :: Libraries :: Python Modules

[options]
package_dir =
    = src
packages = find:
python_requires = >=3.7
install_requires =
    torch>=1.12
    bitsandbytes==0.34.0
    accelerate==0.10.0
    huggingface-hub==0.7.0
    transformers==4.21.3
    protobuf>=3.20.3,<4.0dev
    hivemind>=1.1.3
    humanfriendly
    async-timeout>=4.0.2

[options.extras_require]
dev =
    pytest==6.2.5
    pytest-forked
    pytest-asyncio==0.16.0
    black==22.3.0
    isort==5.10.1
    psutil

[options.packages.find]
where = src

@@ -1,6 +0,0 @@
from src.bloom import *
from src.client import *
from src.dht_utils import declare_active_modules, get_remote_module
project_name = "bloomd"
__version__ = "0.2"

@@ -1,2 +0,0 @@
from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

@@ -1,5 +0,0 @@
from src.client.inference_session import InferenceSession
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager
from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@@ -0,0 +1 @@
__version__ = "1.0alpha1"

@@ -0,0 +1,2 @@
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
-from src.bloom.ops import (
+from petals.bloom.ops import (
BloomGelu,
BloomScaledSoftmax,
attention_mask_func,

@@ -15,7 +15,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.utils.hub import cached_path, hf_bucket_url
-from src.bloom import BloomBlock, BloomConfig
+from petals.bloom import BloomBlock, BloomConfig
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

@@ -26,7 +26,7 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
from transformers.utils import logging
-from src.bloom.block import BloomBlock
+from petals.bloom.block import BloomBlock
use_hivemind_log_handler("in_root_logger")
logger = logging.get_logger(__file__)

@@ -9,9 +9,9 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from huggingface_hub import Repository
from tqdm.auto import tqdm
-from src import BloomModel
-from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
-from src.client import DistributedBloomConfig
+from petals.bloom import BloomModel
+from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
+from petals.client import DistributedBloomConfig
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

@@ -32,7 +32,7 @@ while getopts ":m:i:d:p:b:a:t:" option; do
;;
b) BLOCK_IDS=${OPTARG}
;;
a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
;;
t) RUN_LOCAL_TESTS=true
;;
@@ -67,7 +67,7 @@ else
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-pip install -i https://pypi.org/simple -r requirements.txt
+pip install -i https://pypi.org/simple -r .
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi
@@ -75,5 +75,5 @@ fi
# Run server #
##############
-python -m cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
+python -m petals.cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
--block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log

@@ -4,9 +4,9 @@ import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tqdm.auto import trange
-from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig
-from src.bloom.ops import build_alibi_tensor
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig
+from petals.bloom.ops import build_alibi_tensor
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

@@ -40,7 +40,7 @@ else
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-pip install -i https://pypi.org/simple -r requirements.txt
+pip install -i https://pypi.org/simple -r .
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi
@@ -59,7 +59,7 @@ echo "Initial peer: ${INITIAL_PEER}"
# Initialize the config file #
##############################
typeset -A cfg
cfg=( # set default values in config array
[device]="cpu"
[block_ids]="1:2"
@@ -72,7 +72,7 @@ cfg=( # set default values in config array
###############
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
###############
# Read config #
###############
@@ -85,14 +85,14 @@ do
cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
fi
done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
echo "=== Server #${SERVER_ID} ==="
echo "Server ID: ${cfg[id_path]}"
echo "Device: ${cfg[device]}"
echo "Bloom block ids: ${cfg[block_ids]}"
echo "Host maddr: ${cfg[maddr]}"
echo ""
##############
# Run server #
##############

@@ -45,7 +45,7 @@ else
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
-pip install -i https://pypi.org/simple -r requirements.txt
+pip install -i https://pypi.org/simple -r .
fi
@@ -65,7 +65,7 @@ echo "Initial peer: ${INITIAL_PEER}"
# Initialize the config file #
##############################
typeset -A cfg
cfg=( # set default values in config array
[name]=""
[device]="cpu"
@@ -79,7 +79,7 @@ cfg=( # set default values in config array
###############
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
###############
# Read config #
###############
@@ -92,7 +92,7 @@ do
cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
fi
done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
SERVER_NAME="${USERNAME}@${cfg[name]}"
echo "=== Server #${SERVER_ID} ==="
echo "Server name ${SERVER_NAME}"
@@ -101,10 +101,10 @@ do
echo "Bloom block ids: ${cfg[block_ids]}"
echo "Host maddr: ${cfg[maddr]}"
echo "================="
##############
# Run server #
##############
ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
done

@@ -6,8 +6,8 @@ from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from humanfriendly import parse_size
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.server.server import Server
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.server.server import Server
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

@@ -0,0 +1,5 @@
from petals.client.inference_session import InferenceSession
from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from petals.client.sequence_manager import RemoteSequenceManager
from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@@ -20,10 +20,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2
-from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
-from src.server.handler import TransformerConnectionHandler
-from src.utils.misc import DUMMY, is_dummy
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__file__)

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
-from src.data_structures import ModuleUID, RPCInfo
+from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary(

@@ -3,14 +3,14 @@ from typing import List, Optional
import torch
from hivemind.utils.logging import get_logger
-from src.utils.generation_algorithms import (
+from petals.utils.generation_algorithms import (
BeamSearchAlgorithm,
DecodingAlgorithm,
GreedyAlgorithm,
NucleusAlgorithm,
TopKAlgorithm,
)
-from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
+from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
logger = get_logger(__file__)

@@ -7,7 +7,7 @@ import torch.nn as nn
from hivemind import get_logger, use_hivemind_log_handler
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from src.bloom.model import (
+from petals.bloom.model import (
BloomConfig,
BloomForCausalLM,
BloomForSequenceClassification,
@@ -15,10 +15,10 @@ from src.bloom.model import (
BloomPreTrainedModel,
LMHead,
)
-from src.client.remote_generation import RemoteGenerationMixin
-from src.client.remote_sequential import RemoteSequential
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.utils.misc import DUMMY
+from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_sequential import RemoteSequential
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.utils.misc import DUMMY
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

@@ -7,12 +7,12 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn
-import src
-from src.client.inference_session import InferenceSession
-from src.client.sequence_manager import RemoteSequenceManager
-from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
-from src.data_structures import UID_DELIMITER
-from src.utils.misc import DUMMY
+import petals.client
+from petals.client.inference_session import InferenceSession
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
+from petals.data_structures import UID_DELIMITER
+from petals.utils.misc import DUMMY
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
def __init__(
self,
-config: src.DistributedBloomConfig,
+config: petals.client.DistributedBloomConfig,
dht: DHT,
dht_prefix: Optional[str] = None,
p2p: Optional[P2P] = None,

@@ -9,10 +9,9 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from src.client.spending_policy import NoSpendingPolicy
-from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
-from src.dht_utils import get_remote_module_infos
-from src.server.handler import TransformerConnectionHandler
+import petals.dht_utils
+from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -88,7 +87,9 @@ class RemoteSequenceManager:
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
def update_block_infos_(self):
-new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
+new_block_infos = petals.dht_utils.get_remote_module_infos(
+    self.dht, self.block_uids, expiration_time=float("inf")
+)
assert len(new_block_infos) == len(self.block_uids)
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
if info is None:

@@ -11,11 +11,11 @@ import torch
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
-from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
-from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
-from src.server.handler import TransformerConnectionHandler
-from src.utils.misc import DUMMY, is_dummy
+from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__file__)

@@ -12,8 +12,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
-import src
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
+import petals.client
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -76,10 +76,10 @@ def get_remote_sequence(
dht: DHT,
start: int,
stop: int,
-config: src.DistributedBloomConfig,
+config: petals.client.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
return_future: bool = False,
-) -> Union[src.RemoteSequential, MPFuture]:
+) -> Union[petals.client.RemoteSequential, MPFuture]:
return RemoteExpertWorker.run_coroutine(
_get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
)
@@ -89,22 +89,22 @@ async def _get_remote_sequence(
dht: DHT,
start: int,
stop: int,
-config: src.DistributedBloomConfig,
+config: petals.client.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
-) -> src.RemoteSequential:
+) -> petals.client.RemoteSequential:
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
p2p = await dht.replicate_p2p()
-manager = src.RemoteSequenceManager(dht, uids, p2p)
-return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+manager = petals.client.RemoteSequenceManager(dht, uids, p2p)
+return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
def get_remote_module(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-config: src.DistributedBloomConfig,
+config: petals.client.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
return_future: bool = False,
-) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
+) -> Union[Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]], MPFuture]:
"""
:param uid_or_uids: find one or more modules with these ids from across the DHT
:param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
@@ -119,15 +119,16 @@ def get_remote_module(
async def _get_remote_module(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-config: src.DistributedBloomConfig,
+config: petals.client.DistributedBloomConfig,
dht_prefix: Optional[str] = None,
-) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+) -> Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]]:
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
p2p = await dht.replicate_p2p()
-managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
modules = [
-src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
+for m in managers
]
return modules[0] if single_uid else modules

@@ -6,10 +6,10 @@ from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
-from src.bloom.from_pretrained import BloomBlock
-from src.server.cache import MemoryCache
-from src.server.task_pool import PrioritizedTaskPool
-from src.utils.misc import is_dummy
+from petals.bloom.from_pretrained import BloomBlock
+from petals.server.cache import MemoryCache
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.utils.misc import is_dummy
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
from hivemind import PeerID, get_logger
-from src.data_structures import RemoteModuleInfo, ServerState
+from petals.data_structures import RemoteModuleInfo, ServerState
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]

@@ -21,11 +21,11 @@ from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
-from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import TransformerBackend
-from src.server.task_pool import PrioritizedTaskPool
-from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
-from src.utils.misc import DUMMY, is_dummy
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID
+from petals.server.backend import TransformerBackend
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
+from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__file__)

@@ -16,17 +16,17 @@ from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from src import BloomConfig, declare_active_modules
-from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
-from src.dht_utils import get_remote_module_infos
-from src.server import block_selection
-from src.server.backend import TransformerBackend
-from src.server.cache import MemoryCache
-from src.server.handler import TransformerConnectionHandler
-from src.server.throughput import get_host_throughput
-from src.utils.convert_8bit import replace_8bit_linear
+from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
+from petals.bloom.model import BloomConfig
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from petals.dht_utils import declare_active_modules, get_remote_module_infos
+from petals.server import block_selection
+from petals.server.backend import TransformerBackend
+from petals.server.cache import MemoryCache
+from petals.server.handler import TransformerConnectionHandler
+from petals.server.throughput import get_host_throughput
+from petals.utils.convert_8bit import replace_8bit_linear
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

@@ -11,19 +11,16 @@ from typing import Dict, Union
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from src import project_name
-from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig
-from src.bloom.ops import build_alibi_tensor
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig
+from petals.bloom.ops import build_alibi_tensor
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
-DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
-DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")
-SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
+DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput.json")
+DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
@dataclass
@@ -90,7 +87,7 @@ def measure_throughput_info() -> ThroughputInfo:
def measure_network_rps(config: BloomConfig) -> float:
-proc = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True)
+proc = subprocess.run("python3 -m petals.cli.speed_test --json", shell=True, capture_output=True)
if proc.returncode != 0:
raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
network_info = json.loads(proc.stdout)

@@ -3,16 +3,13 @@ import random
import hivemind
import pytest
import torch
-import transformers
-from hivemind import P2PHandlerError
from test_utils import *
-import src
-from src import DistributedBloomConfig
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteTransformerBlock
-from src.data_structures import UID_DELIMITER
-from src.dht_utils import get_remote_module
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import DistributedBloomConfig
+from petals.client.remote_sequential import RemoteTransformerBlock
+from petals.data_structures import UID_DELIMITER
+from petals.dht_utils import get_remote_module
@pytest.mark.forked

@@ -9,16 +9,16 @@ import pytest
import torch
from test_utils import *
-import src
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteSequential
-from src.dht_utils import get_remote_sequence
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import DistributedBloomConfig
+from petals.client.remote_sequential import RemoteSequential
+from petals.dht_utils import get_remote_sequence
@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
remote_blocks = get_remote_sequence(dht, 3, 6, config)
assert isinstance(remote_blocks, RemoteSequential)
@@ -47,7 +47,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
remote_blocks = get_remote_sequence(dht, 3, 5, config)
assert isinstance(remote_blocks, RemoteSequential)

@@ -5,8 +5,8 @@ from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *
from transformers.generation_utils import BeamSearchScorer
-from src.bloom.model import BloomForCausalLM
-from src.client.remote_model import DistributedBloomForCausalLM
+from petals.bloom.model import BloomForCausalLM
+from petals.client.remote_model import DistributedBloomForCausalLM
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

@@ -4,8 +4,8 @@ import time
import pytest
import torch
-from src.server.runtime import Runtime
-from src.server.task_pool import PrioritizedTaskPool
+from petals.server.runtime import Runtime
+from petals.server.task_pool import PrioritizedTaskPool
@pytest.mark.forked

@@ -3,9 +3,9 @@ import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from test_utils import *
-from src import RemoteSequential
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_model import DistributedBloomConfig
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import RemoteSequential
+from petals.client.remote_model import DistributedBloomConfig
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
