Merge branch 'main' into vadi2-patch-1

pull/438/head
Alexander Borzunov 10 months ago committed by GitHub
commit 92ee7ecb8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,10 +10,20 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
include:
- { model: 'bigscience/bloom-560m', python-version: '3.8' }
- { model: 'bigscience/bloom-560m', python-version: '3.9' }
- { model: 'bigscience/bloom-560m', python-version: '3.10' }
- { model: 'bigscience/bloom-560m', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
fail-fast: false
timeout-minutes: 15
steps:
- name: Increase swap space
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
@ -31,44 +41,77 @@ jobs:
pip install .[dev]
- name: Test
run: |
export MODEL_NAME=bigscience/bloom-560m
export REF_NAME=bigscience/bloom-560m
export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
--torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--adapters $ADAPTER_NAME &> server1.log &
SERVER1_PID=$!
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
# [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
bash -c 'while true; do free -h && sleep 30s; done' &
RAM_WATCH_PID=$!
sleep 5 # wait for the first server to initialize DHT
# [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
python -m petals.cli.run_dht \
--identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
BOOTSTRAP_PID=$!
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
# ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
SERVER2_PID=$!
sleep 5 # wait for DHT init
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
--mean_balance_check_period 10 \
--initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
SERVER1_PID=$!
# ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
sleep 10 # wait for the 1st server to choose blocks
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
--identity_path tests/server2.id \
--initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
SERVER2_PID=$!
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
--attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
SERVER3_PID=$!
# ^-- chunking test
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
--initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
SERVER4_PID=$!
# ^-- tensor parallelism test (not compatible with adapters yet)
tail -n 100 -f server*.log &
sleep 5 # wait for the log files to appear
tail -n 100 -f bootstrap.log server*.log &
LOGGER_PID=$!
sleep 30 # wait for servers to download layers
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
sleep 30 # wait for servers to eval throughput, download layers, and rebalance
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init
# [Step 3] Run PyTest
pytest tests --durations=0 --durations-min=1.0 -v
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests
# [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3
python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --n_steps 1
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
# [Step 5] Clean up
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests
kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
echo "Done!"

@ -8,7 +8,7 @@
<br>
</p>
Generate text with distributed [LLaMA 2 (70B)](https://huggingface.co/meta-llama/Llama-2-70b-hf), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and finetune them for your own tasks &mdash; right from your desktop computer or Google Colab:
```python
from transformers import AutoTokenizer
@ -37,7 +37,7 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/J29mCBNBvm)!
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
### Connect your GPU and increase Petals capacity
@ -48,7 +48,7 @@ Petals is a community-run system &mdash; we rely on people sharing their GPUs. Y
```bash
conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia
pip install git+https://github.com/bigscience-workshop/petals
python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
python -m petals.cli.run_server stabilityai/StableBeluga2
```
🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
@ -57,7 +57,7 @@ python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
```bash
sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 --torch_dtype float16
python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
```
These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
@ -68,7 +68,7 @@ These commands will host a part of [Stable Beluga 2](https://huggingface.co/stab
python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
```
💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!
💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
@ -96,8 +96,8 @@ Learning more:
## How does it work?
- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning.
- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
<p align="center">

@ -15,15 +15,15 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=128)
parser.add_argument("--n_steps", type=int, default=100)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()
if args.n_processes == "n_gpus":

@ -16,13 +16,13 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=2048)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()
if args.n_processes == "n_gpus":

@ -15,18 +15,18 @@ logger = get_logger()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--task", type=str, default="cls")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=128)
parser.add_argument("--pre_seq_len", type=int, default=16)
parser.add_argument("--n_steps", type=int, default=10)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
parser.add_argument("--task", type=str, default="cls", help="Training task type")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()
assert args.task in ["cls", "causal_lm"]

@ -92,9 +92,6 @@
},
"outputs": [],
"source": [
"# Choose a model you'd like to prompt-tune. We recommend starting with\n",
"# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n",
"# The code below uses LLaMA-65B.\n",
"MODEL_NAME = \"enoch/llama-65b-hf\"\n",
"\n",
"# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",

@ -32,7 +32,7 @@ packages = find:
python_requires = >=3.8
install_requires =
torch>=1.12
bitsandbytes==0.40.1.post1
bitsandbytes==0.41.1
accelerate>=0.20.3,<0.21.0
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3

@ -7,8 +7,8 @@ This script may be used for launching lightweight CPU machines serving as bootst
This may be eventually merged to the hivemind upstream.
"""
import argparse
import time
from argparse import ArgumentParser
from secrets import token_hex
from hivemind.dht import DHT, DHTNode
@ -35,7 +35,7 @@ async def report_status(dht: DHT, node: DHTNode):
def main():
parser = ArgumentParser()
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--initial_peers",
nargs="*",
@ -73,7 +73,9 @@ def main():
help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
)
parser.add_argument(
"--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
"--use_auto_relay",
action="store_true",
help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
)
parser.add_argument(
"--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"

@ -122,7 +122,7 @@ def main():
help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")
group = parser.add_mutually_exclusive_group()
group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
group.add_argument('--initial_peers', type=str, nargs='+', required=False, default=PUBLIC_INITIAL_PEERS,
help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
group.add_argument('--new_swarm', action='store_true',
help='Start a new private swarm (i.e., do not connect to any initial peers)')
@ -158,7 +158,7 @@ def main():
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")
parser.add_argument("--adapters", nargs='+', default=(),
parser.add_argument("--adapters", nargs='*', default=(),
help="List of pre-loaded LoRA adapters that can be used for inference or training")
# fmt:on

@ -50,7 +50,7 @@ class SequenceManagerConfig:
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 5 # max servers to ping from each sequence side, per update
max_pinged: int = 3 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
@ -293,6 +293,8 @@ class RemoteSequenceManager:
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
client_server_rtts = self.ping_aggregator.to_dict()
span_sequence = []
current_index = start_index
while current_index < end_index:
@ -300,7 +302,13 @@ class RemoteSequenceManager:
if not candidate_spans:
raise MissingBlocksError(current_index)
span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
# We choose longer servers to minimize the number of hops but leave some randomization
# to distribute the load. We also exclude servers known to be unreachable.
eps = 1e-6
span_weights = np.array(
[span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
dtype=np.float64,
)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
assert chosen_span.start <= current_index < chosen_span.end
@ -361,9 +369,13 @@ class RemoteSequenceManager:
self.state.sequence_info.update_(new_block_infos)
first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
middle_servers = [
span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
]
last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged))
pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)

@ -0,0 +1,195 @@
"""
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
"""
from __future__ import annotations
from typing import AsyncIterator, Optional, Sequence, Tuple, Union
import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.nested import nested_flatten
from petals.data_structures import InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.misc import DUMMY, is_dummy
async def run_rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, prompts = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
async def run_rpc_backward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
async def iterate_rpc_inference(
requested_uids: Sequence[ExpertUID],
requested_backends: Sequence[TransformerBackend],
active_adapter: Optional[str],
input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
cache_handles: Sequence[Sequence[Handle]],
max_length: int,
prioritizer: TaskPrioritizerBase,
points: int,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
assert len(cache_handles) == len(requested_backends)
prefix_length = 0
point_per_piece = points / max_length if max_length > 0 else 0.0
async for request, step_metadata in input_iterator:
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
priority = prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
# serialize and send last layer outputs
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
can_push = not has_prompts
yield output_tensors, can_push
# prepare for next step
prefix_length += length_increment

@ -11,7 +11,8 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]
"""If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
if dtype not in ("auto", None):
return dtype
if config.torch_dtype not in ("auto", None):
if config.torch_dtype not in ("auto", None, torch.float32):
# If config specifies float32, we override it to the default dtype below
return config.torch_dtype
return torch.bfloat16

@ -6,7 +6,7 @@ import multiprocessing as mp
import sys
from enum import Enum
from itertools import chain
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple
import torch
from async_timeout import timeout
@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__name__)
@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
if not requested_uids:
@ -163,78 +161,28 @@ class TransformerConnectionHandler(ConnectionHandler):
f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
)
point_per_piece = points / max_length if max_length > 0 else 0.0
batch_size = request.tensors[0].size[0] if request.tensors else 1
prefix_length = 0
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
first_request = request
background_tasks = set()
async for request, metadata in self._iterate_inference_steps(
first_request, requests, session_id, requested_uids, context
async for output_tensors, can_push in iterate_rpc_inference(
requested_uids=requested_uids,
requested_backends=requested_backends,
active_adapter=self._get_active_adapter(metadata),
input_iterator=self._iterate_inference_steps(
request, requests, session_id, requested_uids, context
),
cache_handles=cache_handles,
max_length=max_length,
prioritizer=self._prioritizer,
points=points,
):
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
priority = self._prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
# serialize and send last layer outputs
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
if not has_prompts:
if can_push:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
background_tasks.add(task) # Keep reference until it is done to save it from GC
task.add_done_callback(background_tasks.discard)
yield runtime_pb2.ExpertResponse(tensors=output_tensors)
# prepare for next step
prefix_length += length_increment
finally:
self._log_request("rpc_inference.close", requested_uids, context)
@ -408,7 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
hidden_states = await run_rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
@ -435,7 +383,7 @@ class TransformerConnectionHandler(ConnectionHandler):
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
hidden_states = await run_rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
@ -486,7 +434,7 @@ class TransformerConnectionHandler(ConnectionHandler):
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grads = await _rpc_backward(
grads = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
@ -511,7 +459,7 @@ class TransformerConnectionHandler(ConnectionHandler):
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await _rpc_backward(
grads = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
@ -621,105 +569,3 @@ class TransformerConnectionHandler(ConnectionHandler):
result.update(block_info)
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
async def _rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, prompts = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
async def _rpc_backward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape

@ -78,7 +78,7 @@ class Server:
sender_threads: int = 1,
balance_quality: float = 0.75,
mean_balance_check_period: float = 120,
mean_block_selection_delay: float = 2.5,
mean_block_selection_delay: float = 5,
token: Optional[Union[str, bool]] = None,
quant_type: Optional[QuantType] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
@ -201,6 +201,8 @@ class Server:
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
num_blocks = self._choose_num_blocks()
if num_blocks is not None:
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
if block_indices is not None:
try:
first_block_index, last_block_index = block_indices.split(":")
@ -295,7 +297,7 @@ class Server:
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
logger.info(
f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
f"Server will fill your GPU memory with {num_blocks} transformer blocks. "
f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
)
return num_blocks
@ -698,7 +700,9 @@ class ModuleAnnouncerThread(threading.Thread):
delay = self.update_period - (time.perf_counter() - start_time)
if delay < 0:
logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it")
logger.warning(
f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
)
self.trigger.wait(max(delay, 0))
self.trigger.clear()

Binary file not shown.

@ -29,6 +29,9 @@ def test_bnb_not_imported_when_unnecessary():
@pytest.mark.parametrize("tensor_parallel", [False, True])
def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
if tensor_parallel and config.model_type != "bloom":
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
compute_rps = measure_compute_rps(
config,

@ -3,14 +3,14 @@ import random
import pytest
import torch
from petals import DistributedBloomConfig, RemoteSequential
from petals import AutoDistributedConfig, RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@pytest.mark.forked
def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)
for block_index in random.sample(range(config.num_hidden_layers), 3):

@ -7,7 +7,7 @@
import pytest
import torch
from petals import DistributedBloomConfig
from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@ -15,7 +15,7 @@ from test_utils import *
@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
assert isinstance(remote_blocks, RemoteSequential)
@ -43,7 +43,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_blocks = RemoteSequential(config, start_block=3, end_block=5)
inputs = torch.randn(1, 8, config.hidden_size)

@ -3,29 +3,31 @@ import pytest
import torch
import transformers
from hivemind import get_logger
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM
from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin
from petals import DistributedBloomForCausalLM
from petals import AutoDistributedModelForCausalLM
from test_utils import *
logger = get_logger(__name__)
@pytest.fixture
def tokenizer():
# We set use_fast=False since LlamaTokenizerFast is slow on load
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME,
initial_peers=INITIAL_PEERS,
low_cpu_mem_usage=True,
torch_dtype=torch.float32,
active_adapter=ADAPTER_NAME if use_peft else None,
)
config = model.config
assert isinstance(model, DistributedBloomForCausalLM)
assert len(model.transformer.h) == model.config.num_hidden_layers
test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
@ -63,7 +65,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
del model, embs, recurrent_outputs
if REF_NAME:
ref_model = transformers.BloomForCausalLM.from_pretrained(
ref_model = transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if use_peft:
@ -86,27 +88,29 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
def test_greedy_generation(tokenizer, max_new_tokens=4):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
)
hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
"input_ids"
]
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
)
hf_outputs_batch = BloomForCausalLM.greedy_search(
hf_outputs_batch = HfGenerationMixin.greedy_search(
model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
)
assert torch.allclose(
@ -117,13 +121,13 @@ def test_greedy_generation(max_new_tokens=4):
@pytest.mark.forked
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(sampling_options, max_new_tokens=4):
def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
torch.manual_seed(0)
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
with torch.random.fork_rng():
remote_outputs = model.generate(
@ -133,7 +137,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs = BloomForCausalLM.sample(
hf_outputs = HfGenerationMixin.sample(
model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
)
assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
@ -149,7 +153,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs_batch = BloomForCausalLM.sample(
hf_outputs_batch = HfGenerationMixin.sample(
model,
input_ids=inputs_batch,
max_length=inputs_batch.size(1) + max_new_tokens,
@ -161,10 +165,9 @@ def test_sampling(sampling_options, max_new_tokens=4):
@pytest.mark.forked
def test_beam_search_generation(max_new_tokens=4, num_beams=2):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
text = "A cat sat on a mat"
inputs = tokenizer(text, return_tensors="pt")["input_ids"]
@ -181,7 +184,7 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
do_early_stopping=False,
)
hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
hf_outputs = BloomForCausalLM.beam_search(
hf_outputs = HfGenerationMixin.beam_search(
model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
)
assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"

@ -4,7 +4,7 @@ import torch.nn.functional as F
from hivemind import DHT, BatchTensorDescriptor, get_logger
from hivemind.proto import runtime_pb2
from petals import DistributedBloomConfig
from petals import AutoDistributedConfig
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.data_structures import UID_DELIMITER
from petals.server.from_pretrained import load_pretrained_block
@ -15,7 +15,7 @@ logger = get_logger(__name__)
@pytest.mark.forked
def test_remote_sequential():
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
grad_proj = torch.randn(1, 5, config.hidden_size)
@ -40,10 +40,10 @@ def test_remote_sequential():
assert hidden.shape == test_inputs.shape
assert hidden.requires_grad
second_half_outputs = second_half(hidden)
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
assert torch.allclose(second_half_outputs, full_outputs, atol=3e-4)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
@ -56,7 +56,7 @@ def test_remote_sequential():
(approx_outputs * grad_proj).sum().backward()
assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-3), "compression was not used"
assert abs(approx_outputs - full_outputs).mean() < 0.01
absmax = abs(full_grad).max()
assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
@ -87,7 +87,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)
inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)

@ -5,7 +5,7 @@ import pytest
import torch
from hivemind import DHT, get_logger
from petals import DistributedBloomConfig
from petals import AutoDistributedConfig
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.data_structures import UID_DELIMITER
from test_utils import *
@ -16,7 +16,7 @@ logger = get_logger(__name__)
@pytest.mark.forked
@pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
def test_sequence_manager_basics(mode: str):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
sequential = RemoteSequential(config, dht=dht)
shutdown_evt = threading.Event()

@ -4,14 +4,16 @@ import hivemind
import pytest
import torch
from petals import DistributedBloomConfig, RemoteSequential
from petals import AutoDistributedConfig, RemoteSequential
from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *
@pytest.mark.forked
def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"] # PeerID from server2.id
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)

@ -14,8 +14,11 @@ from test_utils import MODEL_NAME
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
block_index = random.randint(0, 10)
model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
if model_config.model_type != "bloom":
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
block_index = random.randint(0, 10)
block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
tp_config = None

Loading…
Cancel
Save