pull/493/head
Danny Boy 9 months ago
commit 9027270240

@ -7,20 +7,21 @@ on:
jobs:
run-tests:
runs-on: ubuntu-latest
strategy:
matrix:
include:
- { model: 'bigscience/bloom-560m', python-version: '3.8' }
- { model: 'bigscience/bloom-560m', python-version: '3.9' }
- { model: 'bigscience/bloom-560m', python-version: '3.10' }
- { model: 'bigscience/bloom-560m', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
- { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }
- { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
fail-fast: false
runs-on: ${{ matrix.os }}-latest
timeout-minutes: 15
steps:
- name: Increase swap space
if: ${{ matrix.os == 'ubuntu' }}
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
@ -47,12 +48,7 @@ jobs:
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
# [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
bash -c 'while true; do free -h && sleep 30s; done' &
RAM_WATCH_PID=$!
# [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
# [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
python -m petals.cli.run_dht \
--identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
@ -61,7 +57,7 @@ jobs:
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
# ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs
sleep 5 # wait for DHT init
until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
--mean_balance_check_period 10 \
@ -95,11 +91,15 @@ jobs:
sleep 30 # wait for servers to eval throughput, download layers, and rebalance
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init
# [Step 3] Run PyTest
# [Step 2] Run PyTest
# Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
export no_proxy=*
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
pytest tests --durations=0 --durations-min=1.0 -v
# [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
# [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3
@ -110,9 +110,7 @@ jobs:
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
# [Step 5] Clean up
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests
# [Step 4] Clean up
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
echo "Done!"

@ -14,14 +14,14 @@ Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Guanaco
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM
# Choose any model available at https://health.petals.dev
model_name = "petals-team/StableBeluga2"
# You can also use "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-70b-chat-hf",
# repos with Llama-65B, "bigscience/bloom", or "bigscience/bloomz"
# Connect to a distributed network hosting model layers
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
# Embeddings & prompts are on your device, transformer blocks are distributed across the Internet
# Run the model as if it were on your computer
inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
@ -33,17 +33,15 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
📋 **Terms of use.** Make sure you follow the model license (see [Llama 2](https://bit.ly/llama2-license), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2/blob/main/LICENSE.txt), [Llama](https://bit.ly/llama-license), and [BLOOM](https://bit.ly/bloom-license)).
🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
### Connect your GPU and increase Petals capacity
## Connect your GPU and increase Petals capacity
Petals is a community-run system — we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serving one of the models!
Petals is a community-run system — we rely on people sharing their GPUs. You can check out [available models](https://health.petals.dev) and help serve one of them! As an example, here is how to host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your GPU:
🐍 **Linux + Anaconda.** Run these commands:
🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
@ -51,48 +49,35 @@ pip install git+https://github.com/bigscience-workshop/petals
python -m petals.cli.run_server petals-team/StableBeluga2
```
🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
🐋 **Any OS + Docker.** Run our [Docker](https://www.docker.com) image:
🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
```bash
sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
learningathome/petals:main \
python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
```
These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with Llama-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then use this command for `petals.cli.run_server`:
🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
```bash
python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
brew install python
python3 -m pip install git+https://github.com/bigscience-workshop/petals
python3 -m petals.cli.run_server petals-team/StableBeluga2
```
💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
### Check out tutorials, examples, and more
Basic tutorials:
- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
- Prompt-tune Llama-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
<p align="center">
📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
</p>
Useful tools and advanced guides:
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
- Launch your own swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
- Run a custom foundation model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
Learning more:
🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
- Frequently asked questions: [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions)
- In-depth system description: [paper](https://arxiv.org/abs/2209.01188)
🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
## How does it work?
@ -105,23 +90,28 @@ Learning more:
</p>
<p align="center">
📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
📜 &nbsp;<b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
</p>
## Installation
## 📚 Tutorials, examples, and more
Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/distribution) on Linux:
Basic tutorials:
```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
pip install git+https://github.com/bigscience-workshop/petals
```
- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
- Prompt-tune Llama-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
Useful tools:
- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
Advanced guides:
See the instructions for macOS and Windows, the full requirements, and troubleshooting advice in our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-client).
- Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
- Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
## Benchmarks

@ -18,6 +18,7 @@ classifiers =
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Topic :: Scientific/Engineering
Topic :: Scientific/Engineering :: Mathematics
Topic :: Scientific/Engineering :: Artificial Intelligence
@ -39,7 +40,7 @@ install_requires =
transformers>=4.32.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py
speedtest-cli==2.1.3
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
hivemind==1.1.9
hivemind @ git+https://github.com/learning-at-home/hivemind
tensor_parallel==1.0.23
humanfriendly
async-timeout>=4.0.2

@ -1,7 +1,13 @@
import os
import platform
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
if platform.system() == "Darwin":
# Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
os.environ.setdefault("no_proxy", "*")
os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
import hivemind
import transformers
from packaging import version

@ -1,8 +1,10 @@
import argparse
import logging
import configargparse
import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.limits import increase_file_limit
from hivemind.utils import limits
from hivemind.utils.logging import get_logger
from humanfriendly import parse_size
@ -96,9 +98,9 @@ def main():
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--alloc_timeout', type=float, default=1,
help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
'before rejecting the request')
parser.add_argument('--max_alloc_timeout', type=float, default=600,
help="If the cache is full, the server will wait for memory to be freed up to this many seconds"
" before rejecting the request")
parser.add_argument('--revision', type=str, default=None,
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@ -127,9 +129,9 @@ def main():
group.add_argument('--new_swarm', action='store_true',
help='Start a new private swarm (i.e., do not connect to any initial peers)')
parser.add_argument('--increase_file_limit', action='store_true',
help='On *nix, this will increase the max number of processes '
'a server can spawn before hitting "Too many open files"; Use at your own risk.')
parser.add_argument('--increase_file_limit', type=int, default=4096,
help='On *nix, increase the max number of files a server can open '
'before hitting "Too many open files" (set to zero to keep the system limit)')
parser.add_argument('--stats_report_interval', type=int, required=False,
help='Interval between two reports of batch processing performance statistics')
@ -185,8 +187,10 @@ def main():
args["startup_timeout"] = args.pop("daemon_startup_timeout")
if args.pop("increase_file_limit"):
increase_file_limit()
file_limit = args.pop("increase_file_limit")
if file_limit:
limits.logger.setLevel(logging.WARNING)
limits.increase_file_limit(file_limit, file_limit)
compression_type = args.pop("compression").upper()
compression = getattr(CompressionType, compression_type)
@ -207,6 +211,10 @@ def main():
validate_version()
if not torch.backends.openmp.is_available():
# Necessary to prevent the server from freezing after forks
torch.set_num_threads(1)
server = Server(
**args,
host_maddrs=host_maddrs,

@ -16,7 +16,7 @@ from transformers import PretrainedConfig
from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy
from petals.utils.misc import get_size_in_bytes, is_dummy
logger = get_logger(__name__)
@ -63,7 +63,7 @@ class TransformerBackend(ModuleBackend):
)
self.dtype = backend_dtype
self.dtype_bytes = torch.finfo(self.dtype).bits // 8
self.dtype_bytes = get_size_in_bytes(self.dtype)
self.shard_num_heads = []
for shard in self.module.module_shards:
for submodule in shard.modules():
@ -83,7 +83,7 @@ class TransformerBackend(ModuleBackend):
self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)
def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
"""Create tensor descriptors for attention cache tensors used during inference_step"""

@ -5,6 +5,7 @@ from accelerate import init_empty_weights
from transformers import PretrainedConfig
from petals.utils.convert_block import QuantType
from petals.utils.misc import get_size_in_bytes
def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
@ -37,7 +38,7 @@ def get_block_size(
if location == "memory":
if quant_type == QuantType.NONE:
dtype = resolve_block_dtype(config, dtype)
bytes_per_value = torch.finfo(dtype).bits // 8
bytes_per_value = get_size_in_bytes(dtype)
elif quant_type == QuantType.INT8:
bytes_per_value = 1
elif quant_type == QuantType.NF4:
@ -46,6 +47,6 @@ def get_block_size(
raise ValueError(f"Unsupported quant_type={quant_type}")
elif location == "disk":
dtype = resolve_block_dtype(config, "auto")
bytes_per_value = torch.finfo(dtype).bits // 8
bytes_per_value = get_size_in_bytes(dtype)
return round(n_params * bytes_per_value * (1 + eps))

@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
max_length = metadata.get("max_length")
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
args_structure = metadata.get("args_structure")
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
@ -166,7 +167,9 @@ class TransformerConnectionHandler(ConnectionHandler):
batch_size = request.tensors[0].size[0] if request.tensors else 1
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
async with self._allocate_cache(
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
) as cache_handles:
background_tasks = set()
async for output_tensors, can_push in iterate_rpc_inference(
requested_uids=requested_uids,
@ -528,14 +531,19 @@ class TransformerConnectionHandler(ConnectionHandler):
@contextlib.asynccontextmanager
async def _allocate_cache(
self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
self,
backends: Sequence[TransformerBackend],
*,
batch_size: int,
max_length: int,
timeout: Optional[float],
) -> Sequence[Sequence[Handle]]:
"""
Allocate memory cache for all transformer blocks, return cache handle
:returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
"""
descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
yield nested_pack(handles, descriptors)
def _log_request(

@ -12,12 +12,13 @@ import os
import time
from typing import AsyncContextManager, Dict, Optional, Sequence
import hivemind
import async_timeout
import torch
from hivemind.utils import TensorDescriptor, get_logger
from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
from petals.data_structures import Handle
from petals.utils.asyncio import shield_and_wait
from petals.utils.misc import get_size_in_bytes
logger = get_logger(__name__)
@ -25,11 +26,12 @@ logger = get_logger(__name__)
class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.alloc_timeout = alloc_timeout
self.max_alloc_timeout = max_alloc_timeout
self._lock_metadata = mp.Lock()
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
self.runtime_pid = os.getpid()
@ -46,6 +48,14 @@ class MemoryCache:
def current_size_bytes(self, value: int):
self._current_size.value = value
@property
def enqueued_size_bytes(self) -> int:
return self._enqueued_size.value
@enqueued_size_bytes.setter
def enqueued_size_bytes(self, value: int):
self._enqueued_size.value = value
@property
def bytes_left(self) -> int:
return self.max_size_bytes - self.current_size_bytes
@ -59,11 +69,14 @@ class MemoryCache:
self._handle_counter.value = value
@contextlib.asynccontextmanager
async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
async def allocate_cache(
self, *descriptors: TensorDescriptor, timeout: float
) -> AsyncContextManager[Sequence[Handle]]:
"""
Create a handle that is associated with buffers on a unique device. If the cache is full, raise AllocationFailed.
:param descriptors: one or more tensor descriptors of this size, dtype, etc
:param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
:note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
if not, it will count maximum tensor allocation across devices for the purposes of size limit
@ -73,6 +86,8 @@ class MemoryCache:
"""
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
if self.max_alloc_timeout is not None:
timeout = min(timeout, self.max_alloc_timeout)
max_alloc_size = self.get_allocation_size(*descriptors)
gib = 1024**3
@ -83,10 +98,10 @@ class MemoryCache:
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
)
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
try:
handles = await shield_and_wait(alloc_task)
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
yield handles
finally:
self._free(max_alloc_size, alloc_task)
@ -96,28 +111,59 @@ class MemoryCache:
"""Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
alloc_size_by_device = {}
for descr in descriptors:
tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
return max(alloc_size_by_device.values())
async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
async def _schedule_alloc(
self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
) -> Sequence[Handle]:
"""
This method should be called inside asyncio.shield() because:
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
"""
try:
async with self._wait_for_free_memory(alloc_size, timeout):
with self._lock_metadata:
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
self.current_size_bytes += alloc_size
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
self._pipe_send.send((handles, descriptors))
return handles
except TimeoutError:
raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})")
@contextlib.asynccontextmanager
async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):
start_time = time.perf_counter()
loop = asyncio.get_event_loop()
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
if self.current_size_bytes + alloc_size > self.max_size_bytes:
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
with self._lock_metadata:
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
self.current_size_bytes += alloc_size
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
self._pipe_send.send((handles, descriptors))
return handles
def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
self.enqueued_size_bytes += alloc_size
allocated = False
try:
context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
# contextlib.AsyncExitStack() is used as a null context here
async with context_manager:
if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:
raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
async with enter_asynchronously(self._lock_acquire_memory):
if self.current_size_bytes + alloc_size > self.max_size_bytes:
if timeout == 0:
raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
elapsed_time = time.perf_counter() - start_time
remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None
await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
allocated = True
self.enqueued_size_bytes -= alloc_size
yield
except asyncio.TimeoutError:
raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
finally:
if not allocated:
self.enqueued_size_bytes -= alloc_size
def _free(self, alloc_size: int, alloc_task: asyncio.Task):
if alloc_task.exception() is not None:
return
handles = alloc_task.result()
@ -133,9 +179,10 @@ class MemoryCache:
raise AllocationFailed(
f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
)
timeout = timeout if timeout != float("inf") else None
deadline = None if timeout is None else time.perf_counter() + timeout
while self.current_size_bytes + allocated_size > self.max_size_bytes:
remaining_time = deadline - time.perf_counter() if timeout is not None else None
remaining_time = None if timeout is None else deadline - time.perf_counter()
if not self._memory_freed_event.wait(remaining_time):
raise AllocationFailed(
f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"

@ -140,7 +140,7 @@ class ReachabilityProtocol(ServicerBase):
protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
ready.set_result(True)
logger.info("Reachability service started")
logger.debug("Reachability service started")
async with protocol.serve(common_p2p):
await protocol._stop.wait()

@ -9,7 +9,9 @@ import time
from typing import Dict, List, Optional, Sequence, Union
import hivemind
import psutil
import torch
import torch.mps
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
@ -31,6 +33,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.dht import declare_active_modules, get_remote_module_infos
from petals.utils.misc import get_size_in_bytes
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
from petals.utils.version import get_compatible_model_repo
@ -59,12 +62,12 @@ class Server:
min_batch_size: int = 1,
max_batch_size: Optional[int] = None,
max_chunk_size_bytes: int = 256 * 1024 * 1024,
max_alloc_timeout: float = 600,
attn_cache_tokens: Optional[int] = None,
torch_dtype: str = "auto",
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
alloc_timeout: float = 5,
device: Optional[Union[str, torch.device]] = None,
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
@ -153,13 +156,25 @@ class Server:
self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device(device.type, index=0)
self.device = device
torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
if device.type == "cpu" and torch_dtype == torch.float16:
raise ValueError(
f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
)
if device.type == "mps" and torch_dtype == torch.bfloat16:
logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
torch_dtype = torch.float16
self.torch_dtype = torch_dtype
if tensor_parallel_devices is None:
@ -185,13 +200,14 @@ class Server:
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
self.max_chunk_size_bytes = max_chunk_size_bytes
self.max_alloc_timeout = max_alloc_timeout
# For attention cache in GPU or RAM
if attn_cache_tokens is None:
attn_cache_tokens = 32768 if is_multiquery_attn else 8192
cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
cache_values_per_block //= self.block_config.num_key_value_groups
self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
# For disk cache
self.cache_dir = cache_dir
@ -217,8 +233,6 @@ class Server:
self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
self.alloc_timeout = alloc_timeout
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput_info = get_server_throughput(
@ -253,13 +267,14 @@ class Server:
self.stop = threading.Event()
def _choose_num_blocks(self) -> int:
assert self.device.type == "cuda", (
assert self.device.type in ("cuda", "mps"), (
"GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
"CPU-only servers in the public swarm are discouraged since they are much slower"
)
num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
if num_devices > 1:
assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
memory_per_device = tuple(
torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
)
@ -270,8 +285,10 @@ class Server:
"Please launch individual servers on each GPU or set --num_blocks manually to "
"override this exception."
)
else:
elif self.device.type == "cuda":
total_memory = torch.cuda.get_device_properties(self.device).total_memory
else:
total_memory = psutil.virtual_memory().total
gib = 1024**3
# Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
@ -311,13 +328,13 @@ class Server:
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
alloc_timeout=self.alloc_timeout,
server_info=self.server_info,
block_indices=block_indices,
num_handlers=self.num_handlers,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
max_chunk_size_bytes=self.max_chunk_size_bytes,
max_alloc_timeout=self.max_alloc_timeout,
inference_max_length=self.inference_max_length,
torch_dtype=self.torch_dtype,
cache_dir=self.cache_dir,
@ -373,6 +390,8 @@ class Server:
f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
f"{reserved_vram / gib:.1f} GiB reserved memory"
)
elif self.device.type == "mps":
torch.mps.empty_cache()
def _choose_blocks(self) -> List[int]:
if self.strict_block_indices is not None:
@ -413,12 +432,12 @@ class ModuleContainer(threading.Thread):
converted_model_name_or_path: str,
block_config: PretrainedConfig,
attn_cache_bytes: int,
alloc_timeout: float,
server_info: ServerInfo,
block_indices: List[int],
min_batch_size: int,
max_batch_size: int,
max_chunk_size_bytes: int,
max_alloc_timeout: float,
torch_dtype: torch.dtype,
cache_dir: str,
max_disk_space: int,
@ -434,7 +453,7 @@ class ModuleContainer(threading.Thread):
**kwargs,
) -> ModuleContainer:
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
server_info.state = ServerState.JOINING
dht_announcer = ModuleAnnouncerThread(
@ -663,7 +682,7 @@ class ModuleAnnouncerThread(threading.Thread):
self.server_info = server_info
self.memory_cache = memory_cache
self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
self.bytes_per_token //= block_config.num_key_value_groups
self.update_period = update_period
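To summarize the new device handling in one place, here is a condensed, hedged sketch; `pick_device_and_dtype` is a hypothetical standalone helper written for illustration only, while the real logic lives inside `Server.__init__` as shown in the hunks above:

```python
from typing import Tuple

import torch


def pick_device_and_dtype(requested: torch.dtype) -> Tuple[torch.device, torch.dtype]:
    """Condensed sketch of the updated selection logic (illustration, not the actual method)."""
    # Prefer CUDA, then Apple-silicon MPS, then fall back to CPU
    if torch.cuda.is_available():
        device = torch.device("cuda", index=0)
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    dtype = requested
    if device.type == "cpu" and dtype == torch.float16:
        raise ValueError("float16 is not supported on CPU; use float32 or bfloat16")
    if device.type == "mps" and dtype == torch.bfloat16:
        dtype = torch.float16  # bfloat16 is not supported on MPS; fall back to float16
    return device, dtype
```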

@ -9,6 +9,7 @@ from pathlib import Path
from typing import Dict, Optional, Sequence, Union
import torch
import torch.mps
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig
@ -207,14 +208,12 @@ def measure_compute_rps(
elapsed = 0
dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
_, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time
if device.type == "cuda":
torch.cuda.synchronize(device)
synchronize(device)
start_time = time.perf_counter()
for step in range(n_steps):
for _ in range(n_steps):
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
if device.type == "cuda":
torch.cuda.synchronize(device)
synchronize(device)
elapsed = time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed
@ -230,8 +229,15 @@ def measure_compute_rps(
return device_rps
def synchronize(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize(device)
elif device.type == "mps":
torch.mps.synchronize()
def get_device_name(device: torch.device) -> str:
return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()
def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
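To make the role of the new `synchronize()` helper concrete, here is a small hedged sketch (`benchmark_matmul` is a hypothetical micro-benchmark, not part of the diff) of timing an asynchronous device correctly:

```python
import time

import torch


def synchronize(device: torch.device) -> None:
    # Mirrors the helper added to petals/server/throughput.py in this diff
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        torch.mps.synchronize()


def benchmark_matmul(device_str: str = "cpu", n: int = 512, n_steps: int = 10) -> float:
    """Hypothetical micro-benchmark: matmul steps per second on the given device."""
    device = torch.device(device_str)
    x = torch.randn(n, n, device=device)

    _ = x @ x  # warm-up step, excluded from timing (as measure_compute_rps does)
    synchronize(device)  # CUDA/MPS kernels run asynchronously; drain them before starting the clock

    start = time.perf_counter()
    for _ in range(n_steps):
        _ = x @ x
    synchronize(device)  # ensure all queued kernels finished before reading the clock
    return n_steps / (time.perf_counter() - start)


print(f"{benchmark_matmul():.1f} matmul steps/s")
```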

@ -9,6 +9,16 @@ def is_dummy(tensor: torch.Tensor) -> bool:
return tensor.numel() == 0
SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4}
def get_size_in_bytes(dtype: torch.dtype) -> int:
if dtype in SPECIAL_DTYPE_SIZES:
return SPECIAL_DTYPE_SIZES[dtype]
get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
def docstring_from(source):
def add_docstring(dest):
dest.__doc__ = source.__doc__
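For reference, a quick standalone check of what the new helper returns for a few common dtypes (assuming the function as defined above):

```python
import torch

from petals.utils.misc import get_size_in_bytes

assert get_size_in_bytes(torch.float32) == 4   # torch.finfo path: 32 bits -> 4 bytes
assert get_size_in_bytes(torch.bfloat16) == 2  # torch.finfo path: 16 bits -> 2 bytes
assert get_size_in_bytes(torch.int64) == 8     # torch.iinfo path: 64 bits -> 8 bytes
assert get_size_in_bytes(torch.bool) == 1      # special-cased: iinfo does not support bool
```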

@ -20,6 +20,7 @@ from transformers.utils import get_file_from_repo
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes
logger = get_logger(__name__)
@ -285,5 +286,5 @@ def estimate_adapter_memory_per_block(
block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
)
adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
return adapter_parameters * bytes_per_parameter

@ -0,0 +1,184 @@
import asyncio
import multiprocessing as mp
import random
import time
from typing import Optional
import pytest
import pytest_asyncio # make sure the module exists; otherwise the test will be skipped
import torch
from hivemind import TensorDescriptor
from petals.server.memory_cache import AllocationFailed, MemoryCache
from petals.utils.misc import get_size_in_bytes
def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
if dtype is None:
dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
elem_size_bytes = get_size_in_bytes(dtype)
descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
return descr
@pytest.mark.asyncio
async def test_cache_timeout():
cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
cache.runtime_pid += 1 # pretend we're another process
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
pass
async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
t_start = time.perf_counter()
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
pass
assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
pass
t_start = time.perf_counter()
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0): # exceeds max timeout
pass
assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
# test memory allocation when another task frees the memory
async def _klog_the_cache():
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
pass
large_alloc_task = asyncio.create_task(_klog_the_cache())
t_start = time.perf_counter()
await asyncio.sleep(0.05) # wait for large alloc to enqueue
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")): # exceeds max timeout
pass # this memory should allocate once the background task clears the queue
assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
with pytest.raises(AllocationFailed):
await large_alloc_task
# test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
large_alloc_task = asyncio.create_task(_klog_the_cache())
t_start = time.perf_counter()
await asyncio.sleep(0.05) # wait for large alloc to enqueue
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
pass # this memory should allocate once the background task clears the queue
assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
with pytest.raises(AllocationFailed):
await large_alloc_task
@pytest.mark.asyncio
async def test_unlimited_timeout():
cache = MemoryCache(max_size_bytes=1024)
cache.runtime_pid += 1 # pretend we're another process
t_start = time.perf_counter()
async def _klog_the_cache():
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
await asyncio.sleep(0.5)
alloc_task = asyncio.create_task(_klog_the_cache())
await asyncio.sleep(0.1)
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
await alloc_task
assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
@pytest.mark.asyncio
async def test_cache_usage():
cache = MemoryCache(max_size_bytes=2048)
alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
with pytest.raises(AssertionError):
async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
pass # fails because cache must be allocated from another process
descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8)) # 768 bytes
descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64)) # 8 bytes
descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool)) # 33 bytes
descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64)) # 0 bytes
descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16)) # 1536 bytes
descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8)) # 1792 bytes
async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
loop = asyncio.get_event_loop()
async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
pipe_sender.send(handles)
await loop.run_in_executor(None, dealloc_event.wait)
async def _allocate_af():
alloc_event.wait()
allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
await allocate_a_task
allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f)) # klogs the cache
await allocate_f_task
alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
alloc_process1.start()
async def _allocate_bcde():
alloc_event.wait()
await asyncio.sleep(0.1) # ensure that the other tensor is always allocated (and sent through pipe) first
allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit
await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
alloc_process2.start()
assert cache.current_size_bytes == 0
alloc_event.set()
(handle_a,) = pipe_receiver.recv()
handle_b, handle_c, handle_d = pipe_receiver.recv()
with cache.use_cache(handle_a) as (tensor_a,):
assert tensor_a.dtype == torch.uint8
tensor_a[2:5] = torch.tensor((42, 43, 44))
with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
tensor_a += 1
tensor_b[...] = -1.337
assert cache.current_size_bytes == 809 # this checks that a, b, c, d are allocated, while e still awaits memory
dealloc_bcd_event.set()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 768 # only tensor a should be allocated
with pytest.raises(KeyError):
with cache.use_cache(handle_a, handle_b):
pass # one of handles (c) is deallocated
with pytest.raises(KeyError):
with cache.use_cache(handle_d):
pass # handle_d is deallocated correctly, even though it is never used
with cache.use_cache(handle_a) as (tensor_a,):
assert tuple(tensor_a[2:5]) == (43, 44, 45)
dealloc_a_event.set()
(handle_e,) = pipe_receiver.recv() # e can finally be allocated
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 1536 # tensor e should finally be able to allocate
with pytest.raises(KeyError):
with cache.use_cache(handle_a):
pass # tensor a is no longer allocated
with cache.use_cache(handle_e) as (tensor_e,):
assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
dealloc_e_event.set()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 1792 # only tensor f is still allocated
dealloc_f_event.set()
alloc_process1.join()
alloc_process2.join()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 0
assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"

@ -1,4 +1,5 @@
import multiprocessing as mp
import platform
import time
import pytest
@ -8,9 +9,30 @@ from hivemind.moe.server.runtime import Runtime
from petals.server.task_pool import PrioritizedTaskPool
def _submit_tasks(runtime_ready, pools, results_valid):
runtime_ready.wait()
futures = []
futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
time.sleep(0.01)
futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
for i, f in enumerate(futures):
assert f.result()[0].item() == i**2
results_valid.set()
@pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks")
@pytest.mark.forked
def test_priority_pools():
outputs_queue = mp.SimpleQueue()
runtime_ready = mp.Event()
results_valid = mp.Event()
def dummy_pool_func(x):
@ -31,27 +53,14 @@ def test_priority_pools():
PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
)
# Simulate requests coming from ConnectionHandlers
proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
proc.start()
runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
runtime.ready = runtime_ready
runtime.start()
def process_tasks():
futures = []
futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
time.sleep(0.01)
futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
for i, f in enumerate(futures):
assert f.result()[0].item() == i**2
results_valid.set()
proc = mp.Process(target=process_tasks)
proc.start()
proc.join()
assert results_valid.is_set()
@ -69,3 +78,5 @@ def test_priority_pools():
# 3 - task with priority 2 from pool A
# 4 - task with priority 10 from pool A
# 7 - task with priority 11 from pool B
runtime.shutdown()
