Add local tensor-parallel fwd/bwd (#143)

This pull request adds an option to run a Petals server on multiple local GPUs. It uses https://github.com/BlackSamorez/tensor_parallel.
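
For reference, a minimal usage sketch (not part of this diff; the model name and device list are illustrative placeholders). The new `--tensor_parallel_devices` flag makes the server split every served block across the listed devices:

```bash
# Illustrative only: serve a converted model, splitting each transformer block across two local GPUs.
# Substitute your own model and devices; the CI config below exercises the same flag with "cpu cpu".
python -m petals.cli.run_server --converted_model_name_or_path bigscience/bloom-petals \
    --tensor_parallel_devices cuda:0 cuda:1
```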

- 8-bit approximation error is the same as in main (mean ~= 2%, q0.9 ~= 5%)
    - TP=1, 2, 3 (see screenshots above)
- forward, gradients w.r.t. inputs, and inference exactly match main with TP=1
- `>=` 80% GPU utilization with 3x 1080 Ti, batch = 8 tokens
- throughput measured with and without TP
- TP on 1080 Tis gives near-linear speedup, comparable to the benchmarks (see the first message)


Co-authored-by: Iaroslav Lisniak <yalisnyak@nes.ru>
Co-authored-by: Andrei Panferov <andrei@blacksamorez.ru>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic committed ae9e71fe8e (parent 779959bc70)

@ -86,16 +86,16 @@ jobs:
sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
SERVER3_PID=$!
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \
--torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
SERVER4_PID=$!
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server5.log &
--initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu --torch_dtype float32 &> server5.log &
SERVER5_PID=$!
tail -n 100 -f server*.log &

@ -39,6 +39,7 @@ install_requires =
protobuf>=3.20.3,<4.0dev
speedtest-cli==2.1.3
hivemind==1.1.3
tensor_parallel==1.0.23
humanfriendly
async-timeout>=4.0.2

@ -129,8 +129,12 @@ def main():
parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
parser.add_argument('--load_in_8bit', type=str, default=None,
help="Convert the loaded model into mixed-8bit quantized model. "
help="Convert the loaded transformer blocks into mixed-8bit quantized model. "
"Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
help=
"Split each block between the specified GPUs such that each device holds a portion of every "
"weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
parser.add_argument("--skip_reachability_check", action='store_true',
help="Skip checking this server's reachability via health.petals.ml "

@ -1,9 +1,14 @@
from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict
from typing import Any, Dict, Tuple
from hivemind import PeerID
from petals.server.memory_cache import Handle
ModuleUID = str
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@ -39,3 +44,9 @@ class RemoteSpanInfo:
RPCInfo = Dict[str, Any]
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:
prefix_length: int
cache_handles: Tuple[Handle, ...]

@ -1,12 +1,19 @@
"""Code for serving bloom blocks via hivemind-server"""
from __future__ import annotations
from itertools import chain
from typing import Any, Dict, Sequence, Tuple
import torch
from hivemind import BatchTensorDescriptor
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from tensor_parallel import TensorParallel
from tensor_parallel.tensor_parallel import PerDeviceTensors
from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.bloom.block import WrappedBloomBlock
from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy
@ -17,9 +24,10 @@ logger = get_logger(__file__)
class TransformerBackend(ModuleBackend):
"""A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
super().__init__(*args, **kwargs)
assert isinstance(self.module, WrappedBloomBlock)
assert isinstance(self.module, TensorParallel)
self.config = config
self.memory_cache = memory_cache
for name, param in self.module.named_parameters():
assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
@ -27,18 +35,26 @@ class TransformerBackend(ModuleBackend):
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
max_batch_size = self.forward_pool.max_batch_size
device = self.module.devices[self.module.output_device_index]
self.inference_pool = PrioritizedTaskPool(
self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
)
self.forward_pool = PrioritizedTaskPool(
self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
)
self.backward_pool = PrioritizedTaskPool(
self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
)
assert backend_dtype is not None
self.dtype = backend_dtype
self.shard_num_heads = []
for shard in self.module.module_shards:
for submodule in shard.modules():
if isinstance(submodule, BloomAttention):
self.shard_num_heads.append(submodule.num_heads)
assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
self.inference_schema = (
(
*self.args_schema,
@ -48,44 +64,60 @@ class TransformerBackend(ModuleBackend):
self.kwargs_schema,
)
def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
"""Create tensor descriptors for attention cache tensors used during inference_step"""
head_dim = self.config.hidden_size // self.config.n_head
cache_tensors = []
for device, num_heads in zip(self.module.devices, self.shard_num_heads):
keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
cache_tensors.extend((keys, values))
return cache_tensors
def inference_step(
self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor
self,
hidden_states: torch.Tensor,
hypo_ids: torch.LongTensor,
inference_info: InferenceMetadata,
) -> Tuple[torch.Tensor, ...]:
num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
with torch.inference_mode():
assert (
hidden_states.ndim == 3
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
cache_handle, rel_index, prefix_length = map(int, cache_metadata[0])
with self.memory_cache.use_cache(cache_handle) as cache:
batch_size = cache.shape[2]
max_length = cache.shape[-1] // (head_dim * num_heads)
assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4
if not is_dummy(hypo_ids):
assert hypo_ids.shape[0] == batch_size
cache[rel_index, :, :] = cache[rel_index, :, hypo_ids] # in-place reorder cache by hypo ids
key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)
key_past = key_cache.flatten(0, 1)[:, :, :prefix_length] # [batch * num_heads, head_dim, kv_length]
value_past = value_cache.flatten(0, 1)[:, :prefix_length, :] # [batch * num_heads, kv_length, head_dim]
logger.debug(
f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}"
)
hidden_states, (new_key, new_value) = self.module.forward(
hidden_states, layer_past=(key_past, value_past), use_cache=True
)
new_length = new_key.shape[-1]
assert new_length > prefix_length
assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0]
assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length
new_key = new_key.view(batch_size, num_heads, head_dim, -1)
new_value = new_value.view(batch_size, num_heads, -1, head_dim)
key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
self._reorder_cache_inplace(cache_tensors, hypo_ids)
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
return (hidden_states,)
def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
"""If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
if not is_dummy(hypo_ids):
for cache_tensor in cache_tensors:
cache_tensor[...] = cache_tensor[hypo_ids] # in-place reorder cache by hypo ids
def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
"""Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
for i in range(len(key_cache)):
key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length] # [batch * num_heads, head_dim, kv_length]
value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length] # [batch * num_heads, kv_length, head_dim]
layer_past = tuple(chain(*zip(key_cache, value_cache)))
return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
def _update_cache_inplace(
self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
):
"""Writes new key/value tensors back into cache, works in-place"""
_batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
new_key = new_key.view(*cache_key.shape[:3], new_length)
cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
def get_pools(self) -> Sequence[PrioritizedTaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool

@ -1,5 +1,8 @@
from __future__ import annotations
import asyncio
import contextlib
from itertools import chain
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import torch
@ -8,10 +11,10 @@ from hivemind import (
DHT,
MSGPackSerializer,
P2PContext,
TensorDescriptor,
deserialize_tensor_stream,
deserialize_torch_tensor,
nested_flatten,
nested_pack,
serialize_torch_tensor,
)
from hivemind.moe.server.connection_handler import ConnectionHandler
@ -21,8 +24,9 @@ from hivemind.utils.asyncio import amap_in_executor, anext
from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
from petals.data_structures import CHAIN_DELIMITER, ModuleUID
from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.misc import DUMMY, is_dummy
@ -122,17 +126,12 @@ class TransformerConnectionHandler(ConnectionHandler):
point_per_piece = points / max_length if max_length > 0 else 0.0
batch_size = request.tensors[0].size[0] if request.tensors else 1
cache_metadata = torch.tensor(
[[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
) # [cache_handle, rel_index, prefix_length]
prefix_length = 0
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle:
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
hidden_states, prompts, hypo_ids = [
deserialize_torch_tensor(tensor) for tensor in request.tensors
]
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
@ -155,16 +154,14 @@ class TransformerConnectionHandler(ConnectionHandler):
)
# run request tensors through all requested modules, update caches
for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)):
for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
if hidden_states.numel() == 0:
continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
cache_metadata[:] = torch.tensor(
[cache_handle, rel_index, prefix_length], dtype=torch.int64
)
metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
@ -175,7 +172,6 @@ class TransformerConnectionHandler(ConnectionHandler):
backend.inference_pool, PrioritizedTaskPool
), "petals support only prioritized pools"
priority = self._prioritizer.prioritize(
cache_metadata,
hidden_states,
hypo_ids,
points=point_per_piece / len(requested_backends),
@ -183,7 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler):
type="inference",
)
(hidden_states,) = await backend.inference_pool.submit_task(
hidden_states, hypo_ids, cache_metadata, priority=priority
hidden_states, hypo_ids, metadata, priority=priority
)
# serialize and send last layer outputs
@ -355,28 +351,14 @@ class TransformerConnectionHandler(ConnectionHandler):
@contextlib.asynccontextmanager
async def _allocate_cache(
self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
) -> Sequence[int]:
"""Allocate memory cache for all transformer blocks, return cache handle"""
n_blocks = len(backends)
backend = backends[0]
n_heads = backend.module.self_attention.num_heads
head_dim = backend.module.self_attention.head_dim
descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype)
alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
gib = 1024**3
cur_size = backend.memory_cache.current_size_bytes
max_size = backend.memory_cache.max_size_bytes
friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
logger.info(
f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), "
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
)
async with backend.memory_cache.allocate_cache(descr) as handle:
logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
yield handle
) -> Sequence[Sequence[Handle]]:
"""
Allocate memory cache for all transformer blocks, return cache handle
:returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
"""
descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
yield nested_pack(handles, descriptors)
def _log_request(
self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None

@ -10,7 +10,7 @@ import ctypes
import multiprocessing as mp
import os
import time
from typing import AsyncContextManager, Dict, Optional, Union
from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple
import hivemind
import torch
@ -26,10 +26,9 @@ Handle = int
class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.alloc_timeout = alloc_timeout
self.device = device
self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
@ -57,26 +56,48 @@ class MemoryCache:
self._handle_counter.value = value
@contextlib.asynccontextmanager
async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
"""
Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
:param descr: allocate a tensor of this size, dtype, etc
:param descriptors: one or more tensors of this size, dtype, etc.
:note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
if not, it will count maximum tensor allocation across devices for the purposes of size limit
:note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
Furthermore, it can be called concurrently with at most one use_cache call in runtime.
"""
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
assert descr.device is None and descr
alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
max_alloc_size = self.get_allocation_size(*descriptors)
gib = 1024**3
cur_size, max_size = self.current_size_bytes, self.max_size_bytes
friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
logger.info(
f"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), "
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
)
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
try:
yield await shield_and_wait(alloc_task)
handles = await shield_and_wait(alloc_task)
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
yield handles
finally:
await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
@staticmethod
def get_allocation_size(*descriptors: TensorDescriptor) -> int:
"""Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
alloc_size_by_device = {}
for descr in descriptors:
tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
return max(alloc_size_by_device.values())
async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
"""
This method should be called inside asyncio.shield() because:
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
@ -87,11 +108,11 @@ class MemoryCache:
if self.current_size_bytes + alloc_size > self.max_size_bytes:
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
handle = int(self.handle_counter)
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
self.current_size_bytes += alloc_size
self.handle_counter += 1 # note: this will eventually overflow and it is okay
self._pipe_send.send((handle, descr))
return handle
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
self._pipe_send.send((handles, descriptors))
return handles
async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
"""
@ -102,10 +123,10 @@ class MemoryCache:
if alloc_task.exception() is not None:
return
handle = alloc_task.result()
handles = alloc_task.result()
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
self._pipe_send.send((handle, None)) # signal runtime to free that handle
self._pipe_send.send((handles, None)) # signal runtime to free these handles
self.current_size_bytes -= alloc_size
self._memory_freed_event.set()
@ -125,11 +146,11 @@ class MemoryCache:
self._memory_freed_event.clear()
@contextlib.contextmanager
def use_cache(self, handle: Handle) -> torch.Tensor:
def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:
"""
Return a tensor that was previously allocated with try_allocate_cache,
Return one or more tensors previously allocated with allocate_cache,
:note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
:note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism.
However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
"""
assert os.getpid() == self.runtime_pid
@ -138,20 +159,20 @@ class MemoryCache:
with self._lock_metadata:
# read creation/deletion requests from connection handlers
while self._pipe_recv.poll():
recv_handle, recv_data = self._pipe_recv.recv()
if isinstance(recv_data, TensorDescriptor):
self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
elif recv_data is None:
if recv_handle not in self._allocated_tensors:
logger.warning(
f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
)
self._allocated_tensors.pop(recv_handle, None)
else:
logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
yield self._allocated_tensors[handle]
recv_handles, recv_data = self._pipe_recv.recv()
if recv_data is not None: # create new tensors
assert len(recv_handles) == len(recv_data)
for handle, descr in zip(recv_handles, recv_data):
self._allocated_tensors[handle] = descr.make_zeros()
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
else: # delete tensors by handle
for handle in recv_handles:
if handle not in self._allocated_tensors:
logger.warning(
f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
)
self._allocated_tensors.pop(handle, None)
yield tuple(self._allocated_tensors[handle] for handle in handles)
class AllocationFailed(Exception):

@ -6,7 +6,7 @@ import multiprocessing as mp
import random
import threading
import time
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import psutil
@ -29,7 +29,7 @@ from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.throughput import get_host_throughput
from petals.utils.convert_8bit import replace_8bit_linear
from petals.utils.convert_block import check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
logger = get_logger(__file__)
@ -76,6 +76,7 @@ class Server:
mean_block_selection_delay: float = 2.5,
use_auth_token: Optional[str] = None,
load_in_8bit: Optional[bool] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
skip_reachability_check: bool = False,
**kwargs,
):
@ -128,6 +129,8 @@ class Server:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device(device.type, index=0)
self.device = device
if isinstance(torch_dtype, str):
@ -141,6 +144,13 @@ class Server:
logger.info("Model weights will be loaded in 8-bit format")
self.load_in_8bit = load_in_8bit
if tensor_parallel_devices is None:
tensor_parallel_devices = (device,)
self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
if len(self.tensor_parallel_devices) > 1:
logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
check_device_balance(self.tensor_parallel_devices)
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
num_blocks = self._choose_num_blocks()
@ -174,6 +184,7 @@ class Server:
device,
torch_dtype,
load_in_8bit=load_in_8bit,
tensor_parallel_devices=self.tensor_parallel_devices,
force_eval=(throughput == "eval"),
cache_dir=cache_dir,
)
@ -214,13 +225,28 @@ class Server:
self.converted_model_name_or_path == "bigscience/bloom-petals"
), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
if num_devices > 1:
memory_per_device = tuple(
torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
)
total_memory = min(memory_per_device) * num_devices
if max(memory_per_device) / min(memory_per_device) > 1.5:
raise ValueError(
"GPU devices have highly uneven memory, which makes tensor parallelism inefficient. "
"Please launch individual servers on each GPU or set --num_blocks manually to "
"override this exception."
)
else:
total_memory = torch.cuda.get_device_properties(self.device).total_memory
total_memory = torch.cuda.get_device_properties(self.device).total_memory
block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
gib = 1024**3
attn_cache_per_block = 0.5 * gib # TODO: This does not account for manually set --attn_cache_size
attn_cache_per_block = 0.5 * gib * num_devices # TODO: This does not account for manually set --attn_cache_size
num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
autograd_memory = 2 * gib * num_devices # gpu memory used for intermediate tensors in rpc_backward
num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
logger.info(
@ -260,6 +286,7 @@ class Server:
sender_threads=self.sender_threads,
use_auth_token=self.use_auth_token,
load_in_8bit=self.load_in_8bit,
tensor_parallel_devices=self.tensor_parallel_devices,
start=True,
)
try:
@ -352,6 +379,7 @@ class ModuleContainer(threading.Thread):
expiration: Optional[float],
use_auth_token: Optional[str],
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
**kwargs,
) -> ModuleContainer:
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@ -367,7 +395,9 @@ class ModuleContainer(threading.Thread):
joining_announcer.start()
logger.info(f"Announced that blocks {block_indices} are joining")
memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
blocks = {}
try:
for module_uid, block_index in zip(module_uids, block_indices):
@ -380,18 +410,13 @@ class ModuleContainer(threading.Thread):
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
if load_in_8bit:
block = replace_8bit_linear(block)
block = block.to(device)
for param in block.parameters():
param.requires_grad = False
backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
blocks[module_uid] = TransformerBackend(
module_uid,
block,
config=block_config,
memory_cache=memory_cache,
backend_dtype=backend_dtype,
args_schema=(
@ -451,6 +476,7 @@ class ModuleContainer(threading.Thread):
request_timeout: float,
session_timeout: float,
step_timeout: float,
device: Union[str, torch.device],
start: bool,
**kwargs,
):
@ -469,7 +495,8 @@ class ModuleContainer(threading.Thread):
)
for _ in range(num_handlers)
]
self.runtime = Runtime(self.module_backends, **kwargs)
self.runtime = Runtime(self.module_backends, device=None, **kwargs)
# note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
self.online_announcer = ModuleAnnouncerThread(
list(self.module_backends.keys()),
dht,

@ -5,7 +5,7 @@ import time
from concurrent.futures._base import PENDING
from dataclasses import dataclass, field
from queue import PriorityQueue
from typing import Any, List, Optional, Sequence, Tuple
from typing import Any, List, Optional, Sequence, Tuple, Union
import torch
from hivemind import get_logger
@ -43,6 +43,7 @@ class PrioritizedTaskPool(TaskPoolBase):
:param name: pool name, used for logging
:param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
:param device: if specified, input tensors will be moved to that device by default
:param start: if True, start automatically at the end of __init__
"""
@ -52,11 +53,13 @@ class PrioritizedTaskPool(TaskPoolBase):
max_batch_size: int,
name: str,
min_batch_size=1,
device: Optional[torch.device] = None,
daemon=True,
start=False,
):
super().__init__(process_func, daemon=daemon, name=name)
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.device = device
self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers
self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime
@ -101,7 +104,7 @@ class PrioritizedTaskPool(TaskPoolBase):
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
self.terminate()
def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
"""Add task to this pool's queue, return Future for its output"""
future = MPFuture()
# Remove shmem from MPFuture. This disables the .cancel() feature but
@ -129,10 +132,9 @@ class PrioritizedTaskPool(TaskPoolBase):
self, timeout: Optional[float] = None, device: Optional[torch.device] = None
) -> Tuple[Any, List[torch.Tensor]]:
"""receive next batch of arrays"""
device = device if device is not None else self.device
task = self._ordered_tasks.get(block=True, timeout=timeout)
batch_inputs = [
tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
]
batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
self._dispatched_tasks[task.uid] = task
self.batch_receiver.recv() # reduce the number of active batches
if not self._ordered_tasks.empty():
@ -142,11 +144,7 @@ class PrioritizedTaskPool(TaskPoolBase):
def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
"""send results for a processed batch, previously loaded through load_batch_to_runtime"""
batch_outputs = [
tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
for tensor in batch_outputs
]
batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]
task = self._dispatched_tasks.pop(uid, None)
if task is None:
logger.error(
@ -182,3 +180,13 @@ class PrioritizedTaskPool(TaskPoolBase):
assert len(item) == 2
self._priority.value = float(item[0])
self._oldest_undispatched_timestamp.value = float(item[1])
def _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str], share_memory: bool = False):
if isinstance(arg, torch.Tensor):
arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad)
# note: it is important that non_blocking is disabled if share_memory=True; using share_memory on a tensor
# produced by a non-blocking copy will result in undefined behavior (depending on your gpu speed)
if share_memory:
arg = arg.share_memory_()
return arg

@ -2,9 +2,10 @@ import fcntl
import json
import os
import time
from collections import Counter
from hashlib import sha256
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Sequence, Union
import torch
from hivemind.utils.logging import get_logger
@ -12,7 +13,7 @@ from transformers import BloomConfig
from petals.bloom.block import WrappedBloomBlock
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_8bit import replace_8bit_linear
from petals.utils.convert_block import convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
logger = get_logger(__file__)
@ -37,6 +38,7 @@ def get_host_throughput(
dtype: Union[str, torch.dtype],
*,
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
force_eval: bool = False,
cache_dir: Optional[str] = None,
) -> float:
@ -57,6 +59,9 @@ def get_host_throughput(
cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
if len(tensor_parallel_devices) > 1:
for i, device_i in enumerate(tensor_parallel_devices):
cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"
cache = {}
try:
@ -69,7 +74,9 @@ def get_host_throughput(
cache = {}
if cache_key not in cache:
cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
cache[cache_key] = measure_throughput_info(
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
)
try:
os.makedirs(cache_path.parent, exist_ok=True)
@ -87,6 +94,7 @@ def measure_throughput_info(
dtype: torch.dtype,
*,
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
) -> float:
"""Measure network and compute throughput in forward pass tokens per second"""
@ -95,7 +103,9 @@ def measure_throughput_info(
)
return min(
measure_network_rps(config),
measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit),
measure_compute_rps(
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
),
)
@ -129,14 +139,15 @@ def measure_compute_rps(
dtype: torch.dtype,
*,
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
n_tokens: int = 16,
n_steps: int = 500,
) -> float:
if not tensor_parallel_devices:
tensor_parallel_devices = (device,)
with torch.inference_mode():
block = WrappedBloomBlock(config).to(dtype)
if load_in_8bit:
block = replace_8bit_linear(block)
block = block.to(device)
block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
cache = None
elapsed = 0
@ -149,9 +160,13 @@ def measure_compute_rps(
elapsed += time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed
devices_repr = get_device_name(device)
if len(tensor_parallel_devices) > 1:
device_names = tuple(map(get_device_name, map(torch.device, tensor_parallel_devices)))
devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
logger.info(
f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
f"{device_rps:.1f} RPS"
f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
)
return device_rps

@ -1,39 +0,0 @@
import bitsandbytes as bnb
import torch
from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
def replace_8bit_linear(model, threshold=6.0):
"""
A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
be kept as a `torch.nn.Linear` module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
threshold (`float`, *optional*):
`int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
`6.0` as described by the paper.
"""
for n, module in model.named_children():
if len(list(module.children())) > 0:
replace_8bit_linear(module, threshold)
if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
model._modules[n] = CustomLinear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=threshold,
)
model._modules[n].weight = bnb.nn.Int8Params(
module.weight.data, requires_grad=False, has_fp16_weights=False
).to(module.weight.dtype)
model._modules[n].bias = module.bias
return model

@ -0,0 +1,132 @@
"""
Tools for converting transformer blocks, applying quantization and/or tensor parallelism
"""
import re
from typing import Sequence
import bitsandbytes as bnb
import tensor_parallel as tp
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.bloom.block import WrappedBloomBlock
from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
def convert_block(
block: WrappedBloomBlock,
config: BloomConfig,
tensor_parallel_devices: Sequence[torch.device],
output_device: torch.device,
load_in_8bit: bool,
threshold: float = 6.0,
freeze: bool = True,
) -> tp.TensorParallel:
"""
Optimize a transformer block for use in a Petals server: apply tensor parallelism and/or LLM.int8() quantization
:note: some optimizations will modify the input block in-place!
:param block: a single transformer block, either pre-trained or newly initialized
:param config: HF transformers config for the full model
:param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
:note: if there is only a single device, the model will still be wrapped with TensorParallel (for uniformity)
:param output_device: the device on which the block's outputs are gathered when tensor parallelism is used
:param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
:param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
:param freeze: if True (default), make all module parameters non-trainable
:return: a module that acts like the original block, but runs with all specified optimizations
"""
if freeze:
for param in block.parameters():
param.requires_grad = False
block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
if load_in_8bit:
block = replace_8bit_linear(block, threshold=threshold)
for shard, device in zip(block.module_shards, block.devices):
shard.to(device)
return block
def replace_8bit_linear(model: nn.Module, threshold=6.0):
"""
A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
version of your hardware is installed before running this function: `pip install -i https://test.pypi.org/simple/
bitsandbytes-cudaXXX`, where `XXX` is your CUDA version (e.g., 11.6 = 116).
The function is run recursively and replaces all `torch.nn.Linear` modules except for `lm_head` and `score`, which
should be kept as `torch.nn.Linear` modules.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
threshold (`float`, *optional*):
`int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
`6.0` as described in the paper.
"""
for n, module in model.named_children():
if len(list(module.children())) > 0:
replace_8bit_linear(module, threshold)
if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
model._modules[n] = CustomLinear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=threshold,
)
model._modules[n].weight = bnb.nn.Int8Params(
module.weight.data, requires_grad=False, has_fp16_weights=False
).to(module.weight.dtype)
model._modules[n].bias = module.bias
return model
def make_tensor_parallel(
block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
):
assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
tp_config = get_bloom_config(model_config, devices)
del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
total_heads = 0
for tp_shard in tp_block.module_shards:
for submodule in tp_shard.modules():
if isinstance(submodule, BloomAttention):
total_heads += submodule.num_heads
assert total_heads == model_config.n_head
return tp_block
def check_device_balance(devices: Sequence[torch.device]):
if not all(device.type == "cuda" for device in devices):
logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
return
unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
if len(unique_device_capabilities) > 1:
logger.warning(
f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
)
memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
used_memory = min(memory_per_device) * len(memory_per_device)
wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
if wasted_memory_rate > 0.05:
logger.warning(
f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
f"Consider running high-memory GPUs in a separate server."
)

@ -7,10 +7,17 @@ from petals.server.throughput import measure_compute_rps, measure_network_rps
@pytest.mark.forked
def test_throughput_basic():
@pytest.mark.parametrize("tensor_parallel", [False, True])
def test_throughput_basic(tensor_parallel: bool):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
compute_rps = measure_compute_rps(
config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
config,
device=torch.device("cpu"),
dtype=torch.bfloat16,
load_in_8bit=False,
tensor_parallel_devices=tensor_parallel_devices,
n_steps=10,
)
assert isinstance(compute_rps, float) and compute_rps > 0
network_rps = measure_network_rps(config)

@ -13,7 +13,7 @@ from petals.dht_utils import get_remote_module
@pytest.mark.forked
def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)

@ -1,6 +1,7 @@
import pytest
import torch
from hivemind import DHT, BatchTensorDescriptor, get_logger
import torch.nn.functional as F
from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
from hivemind.proto import runtime_pb2
from test_utils import *
@ -39,10 +40,10 @@ def test_remote_sequential():
assert hidden.shape == test_inputs.shape
assert hidden.requires_grad
second_half_outputs = second_half(hidden)
assert torch.allclose(second_half_outputs, full_outputs)
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad)
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
@ -58,7 +59,7 @@ def test_remote_sequential():
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
assert abs(approx_outputs - full_outputs).mean() < 0.01
absmax = abs(full_grad).max()
assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01
assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
class DummyCustomSequenceManager(RemoteSequenceManager):
@ -87,9 +88,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
remote_sequential = RemoteSequential(config, dht)
inputs = torch.randn(batch_size, seq_len, config.hidden_size)
output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
input_prompts = input_prompts.detach().requires_grad_(True)
@ -117,10 +118,10 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
(outputs_ref,) = block(outputs_ref)
assert torch.allclose(outputs_ref, outputs)
assert torch.allclose(outputs_ref, outputs, atol=1e-3)
(outputs_ref * output_proj).sum().backward()
assert input_prompts_ref.grad is not None
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
assert intermediate_prompts_ref.grad is not None
assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)

@ -0,0 +1,46 @@
import random
import pytest
import torch
import transformers
from tensor_parallel import TensorParallel
from tensor_parallel.slicing_configs import get_bloom_config
from test_utils import MODEL_NAME
from petals.bloom.from_pretrained import load_pretrained_block
@pytest.mark.forked
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
block_index = random.randint(0, 10)
model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
tp_config = None
if custom_config:
tp_config = get_bloom_config(model_config, devices)
batch_size = 2
prefix_length = 5
test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
grad_proj = torch.rand_like(test_inputs1)
y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
y_ref.backward(grad_proj)
block_tp = TensorParallel(block, devices, config=tp_config)
y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
y_ours.backward(grad_proj)
assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
assert torch.allclose(y_ours, y_ref, atol=1e-6)
assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)