diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index aa8b114..af6299b 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -86,16 +86,16 @@ jobs:
           sleep 10  # wait for initial servers to declare blocks, then let server decide which blocks to serve
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \
             --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
           SERVER3_PID=$!
 
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \
             --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
           SERVER4_PID=$!
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server5.log &
+            --initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu --torch_dtype float32 &> server5.log &
           SERVER5_PID=$!
 
           tail -n 100 -f server*.log &
diff --git a/setup.cfg b/setup.cfg
index 15cdfe0..effa114 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -39,6 +39,7 @@ install_requires =
     protobuf>=3.20.3,<4.0dev
    speedtest-cli==2.1.3
    hivemind==1.1.3
+    tensor_parallel==1.0.23
    humanfriendly
    async-timeout>=4.0.2
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 79c1b9d..e089937 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -129,8 +129,12 @@ def main():
     parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
     parser.add_argument('--load_in_8bit', type=str, default=None,
-                        help="Convert the loaded model into mixed-8bit quantized model. "
+                        help="Convert the loaded transformer blocks into a mixed-8bit quantized model. "
                              "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
+    parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
+                        help=
+                        "Split each block between the specified GPUs such that each device holds a portion of every "
+                        "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
 
     parser.add_argument("--skip_reachability_check", action='store_true',
                         help="Skip checking this server's reachability via health.petals.ml "
diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py
index 919c8c1..d5a7181 100644
--- a/src/petals/data_structures.py
+++ b/src/petals/data_structures.py
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
+import dataclasses
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 from hivemind import PeerID
 
+from petals.server.memory_cache import Handle
+
 ModuleUID = str
 UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@@ -39,3 +44,9 @@ class RemoteSpanInfo:
 
 
 RPCInfo = Dict[str, Any]
+
+
+@dataclasses.dataclass(frozen=True)
+class InferenceMetadata:
+    prefix_length: int
+    cache_handles: Tuple[Handle, ...]
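For orientation, here is a minimal sketch of how the new InferenceMetadata is meant to be used in place of the old packed cache_metadata tensor. The handle values below are hypothetical; real ones come from MemoryCache.allocate_cache, one (key, value) pair of handles per tensor-parallel shard.

    from petals.data_structures import InferenceMetadata

    # hypothetical handles for a single backend: (key, value) handles for each of two shards
    cache_handles = (0, 1, 2, 3)
    metadata = InferenceMetadata(prefix_length=10, cache_handles=cache_handles)

    # the connection handler submits (hidden_states, hypo_ids, metadata) to backend.inference_pool, and
    # TransformerBackend.inference_step reads the cache via memory_cache.use_cache(*metadata.cache_handles)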
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 9aa4ea5..67b03c0 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -1,12 +1,19 @@
 """Code for serving bloom blocks via hivemind-server"""
+from __future__ import annotations
+
+from itertools import chain
 from typing import Any, Dict, Sequence, Tuple
 
 import torch
-from hivemind import BatchTensorDescriptor
+from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
+from tensor_parallel import TensorParallel
+from tensor_parallel.tensor_parallel import PerDeviceTensors
+from transformers import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
 
-from petals.bloom.block import WrappedBloomBlock
+from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
@@ -17,9 +24,10 @@ logger = get_logger(__file__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
+    def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
-        assert isinstance(self.module, WrappedBloomBlock)
+        assert isinstance(self.module, TensorParallel)
+        self.config = config
         self.memory_cache = memory_cache
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
@@ -27,18 +35,26 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
+        device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
         )
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
         )
         assert backend_dtype is not None
         self.dtype = backend_dtype
+        self.shard_num_heads = []
+        for shard in self.module.module_shards:
+            for submodule in shard.modules():
+                if isinstance(submodule, BloomAttention):
+                    self.shard_num_heads.append(submodule.num_heads)
+        assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
+
         self.inference_schema = (
             (
                 *self.args_schema,
@@ -48,44 +64,60 @@ class TransformerBackend(ModuleBackend):
             self.kwargs_schema,
         )
 
+    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
+        """Create tensor descriptors for attention cache tensors used during inference_step"""
+        head_dim = self.config.hidden_size // self.config.n_head
+        cache_tensors = []
+        for device, num_heads in zip(self.module.devices, self.shard_num_heads):
+            keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
+            values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
+            cache_tensors.extend((keys, values))
+        return cache_tensors
+
     def inference_step(
-        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
         with torch.inference_mode():
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            cache_handle, rel_index, prefix_length = map(int, cache_metadata[0])
-
-            with self.memory_cache.use_cache(cache_handle) as cache:
-                batch_size = cache.shape[2]
-                max_length = cache.shape[-1] // (head_dim * num_heads)
-                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4
-                if not is_dummy(hypo_ids):
-                    assert hypo_ids.shape[0] == batch_size
-                    cache[rel_index, :, :] = cache[rel_index, :, hypo_ids]  # in-place reorder cache by hypo ids
-                key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
-                value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)
-
-                key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
-                value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
-                logger.debug(
-                    f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}"
-                )
-                hidden_states, (new_key, new_value) = self.module.forward(
-                    hidden_states, layer_past=(key_past, value_past), use_cache=True
-                )
-                new_length = new_key.shape[-1]
-                assert new_length > prefix_length
-                assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0]
-                assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length
-                new_key = new_key.view(batch_size, num_heads, head_dim, -1)
-                new_value = new_value.view(batch_size, num_heads, -1, head_dim)
-                key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
-                value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
+            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+                self._reorder_cache_inplace(cache_tensors, hypo_ids)
+                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
             return (hidden_states,)
 
+    def _reorder_cache_inplace(self, cache_tensors: Sequence[torch.Tensor], hypo_ids: torch.Tensor):
+        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
+        if not is_dummy(hypo_ids):
+            for cache_tensor in cache_tensors:
+                cache_tensor[...] = cache_tensor[hypo_ids]  # in-place reorder cache by hypo ids
+
+    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
+        """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
+        key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
+        for i in range(len(key_cache)):
+            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
+            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
+        layer_past = tuple(chain(*zip(key_cache, value_cache)))
+        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
+
+    def _update_cache_inplace(
+        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
+    ):
+        """Writes new key/value tensors back into cache, works in-place"""
+        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
+        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
+            new_key = new_key.view(*cache_key.shape[:3], new_length)
+            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
+        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
+            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
+            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
+
     def get_pools(self) -> Sequence[PrioritizedTaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index ff66e4b..387431a 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import asyncio
 import contextlib
+from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
@@ -8,10 +11,10 @@ from hivemind import (
     DHT,
     MSGPackSerializer,
     P2PContext,
-    TensorDescriptor,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
+    nested_pack,
     serialize_torch_tensor,
 )
 from hivemind.moe.server.connection_handler import ConnectionHandler
@@ -21,8 +24,9 @@ from hivemind.utils.asyncio import amap_in_executor, anext
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
-from petals.data_structures import CHAIN_DELIMITER, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
 from petals.server.backend import TransformerBackend
+from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from petals.utils.misc import DUMMY, is_dummy
@@ -122,17 +126,12 @@ class TransformerConnectionHandler(ConnectionHandler):
                 point_per_piece = points / max_length if max_length > 0 else 0.0
 
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
-
-                cache_metadata = torch.tensor(
-                    [[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
-                )  # [cache_handle, rel_index, prefix_length]
                 prefix_length = 0
 
-                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle:
+                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+                    assert len(cache_handles) == len(requested_backends)
                     while request.tensors:  # iterate while user is willing to supply tensors
-                        hidden_states, prompts, hypo_ids = [
-                            deserialize_torch_tensor(tensor) for tensor in request.tensors
-                        ]
+                        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
 
                         # Cast inputs to backend dtype
                         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -155,16 +154,14 @@ class TransformerConnectionHandler(ConnectionHandler):
                         )
 
                         # run request tensors through all requested modules, update caches
-                        for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)):
+                        for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
                             if not is_dummy(prompt):
                                 hidden_states[:, : prompt.shape[1]] += prompt
                             if hidden_states.numel() == 0:
                                 continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
                                 # when user wants to pre-allocate cache or check that server *can* allocate that cache
 
-                            cache_metadata[:] = torch.tensor(
-                                [cache_handle, rel_index, prefix_length], dtype=torch.int64
-                            )
+                            metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
                             assert isinstance(
                                 hidden_states, torch.Tensor
                             ), f"hidden states must be tensor, got {type(hidden_states)}"
@@ -175,7 +172,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 backend.inference_pool, PrioritizedTaskPool
                             ), "petals support only prioritized pools"
                             priority = self._prioritizer.prioritize(
-                                cache_metadata,
                                 hidden_states,
                                 hypo_ids,
                                 points=point_per_piece / len(requested_backends),
@@ -183,7 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 type="inference",
                             )
                             (hidden_states,) = await backend.inference_pool.submit_task(
-                                hidden_states, hypo_ids, cache_metadata, priority=priority
+                                hidden_states, hypo_ids, metadata, priority=priority
                             )
 
                         # serialize and send last layer outputs
@@ -355,28 +351,14 @@ class TransformerConnectionHandler(ConnectionHandler):
     @contextlib.asynccontextmanager
     async def _allocate_cache(
         self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
-    ) -> Sequence[int]:
-        """Allocate memory cache for all transformer blocks, return cache handle"""
-
-        n_blocks = len(backends)
-        backend = backends[0]
-        n_heads = backend.module.self_attention.num_heads
-        head_dim = backend.module.self_attention.head_dim
-        descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype)
-        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
-
-        gib = 1024**3
-        cur_size = backend.memory_cache.current_size_bytes
-        max_size = backend.memory_cache.max_size_bytes
-        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
-        logger.info(
-            f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), "
-            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
-        )
-
-        async with backend.memory_cache.allocate_cache(descr) as handle:
-            logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
-            yield handle
+    ) -> Sequence[Tuple[Handle, ...]]:
+        """
+        Allocate memory cache for all transformer blocks, return cache handles
+        :returns: a list of {len(backends)} elements, where the i-th element is a tuple of cache handles for the i-th backend
+        """
+        descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
+        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
+            yield nested_pack(handles, descriptors)
 
     def _log_request(
         self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py
index 53c1a7d..0e39cf5 100644
--- a/src/petals/server/memory_cache.py
+++ b/src/petals/server/memory_cache.py
@@ -10,7 +10,7 @@ import ctypes
 import multiprocessing as mp
 import os
 import time
-from typing import AsyncContextManager, Dict, Optional, Union
+from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple
 
 import hivemind
 import torch
@@ -26,10 +26,9 @@ Handle = int
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
-    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.alloc_timeout = alloc_timeout
-        self.device = device
         self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -57,26 +56,48 @@ class MemoryCache:
         self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
+    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 
-        :param descr: allocate a tensor of this size, dtype, etc
+        :param descriptors: one or more tensor descriptors specifying the size, dtype and device of each allocated tensor
+
+        :note: if descriptors reside on different devices, they are expected to be approximately balanced across devices;
+          if not, the maximum allocation across devices is counted towards the size limit
 
         :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
         Furthermore, it can be called concurrently with at most one use_cache call in runtime.
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
-        assert descr.device is None and descr
-
-        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
+        assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
+        max_alloc_size = self.get_allocation_size(*descriptors)
+
+        gib = 1024**3
+        cur_size, max_size = self.current_size_bytes, self.max_size_bytes
+        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+        logger.info(
+            f"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), "
+            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+        )
+
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
         try:
-            yield await shield_and_wait(alloc_task)
+            handles = await shield_and_wait(alloc_task)
+            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            yield handles
         finally:
-            await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
-
-    async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
+            await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+
+    @staticmethod
+    def get_allocation_size(*descriptors: TensorDescriptor) -> int:
+        """Return the memory size (bytes) to be allocated on a device. If there are many devices, return the maximum"""
+        alloc_size_by_device = {}
+        for descr in descriptors:
+            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
+        return max(alloc_size_by_device.values())
+
+    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
@@ -87,11 +108,11 @@ class MemoryCache:
         if self.current_size_bytes + alloc_size > self.max_size_bytes:
             await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
         async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-            handle = int(self.handle_counter)
+            handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
             self.current_size_bytes += alloc_size
-            self.handle_counter += 1  # note: this will eventually overflow and it is okay
-            self._pipe_send.send((handle, descr))
-            return handle
+            self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
+            self._pipe_send.send((handles, descriptors))
+            return handles
 
     async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
         """
@@ -102,10 +123,10 @@ class MemoryCache:
 
         if alloc_task.exception() is not None:
             return
-        handle = alloc_task.result()
+        handles = alloc_task.result()
 
         async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-            self._pipe_send.send((handle, None))  # signal runtime to free that handle
+            self._pipe_send.send((handles, None))  # signal runtime to free these handles
             self.current_size_bytes -= alloc_size
             self._memory_freed_event.set()
@@ -125,11 +146,11 @@ class MemoryCache:
         self._memory_freed_event.clear()
 
     @contextlib.contextmanager
-    def use_cache(self, handle: Handle) -> torch.Tensor:
+    def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:
         """
-        Return a tensor that was previously allocated with try_allocate_cache,
+        Return one or more tensors previously allocated with allocate_cache.
 
-        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
+        :note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism.
         However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
         """
         assert os.getpid() == self.runtime_pid
@@ -138,20 +159,20 @@ class MemoryCache:
         with self._lock_metadata:
             # read creation/deletion requests from connection handlers
             while self._pipe_recv.poll():
-                recv_handle, recv_data = self._pipe_recv.recv()
-                if isinstance(recv_data, TensorDescriptor):
-                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
-                elif recv_data is None:
-                    if recv_handle not in self._allocated_tensors:
-                        logger.warning(
-                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
-                        )
-                    self._allocated_tensors.pop(recv_handle, None)
-                else:
-                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
-
-        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
-        yield self._allocated_tensors[handle]
+                recv_handles, recv_data = self._pipe_recv.recv()
+                if recv_data is not None:  # create new tensors
+                    assert len(recv_handles) == len(recv_data)
+                    for handle, descr in zip(recv_handles, recv_data):
+                        self._allocated_tensors[handle] = descr.make_zeros()
+                        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+                else:  # delete tensors by handle
+                    for handle in recv_handles:
+                        if handle not in self._allocated_tensors:
+                            logger.warning(
+                                f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
+                            )
+                        self._allocated_tensors.pop(handle, None)
+        yield tuple(self._allocated_tensors[handle] for handle in handles)
 
 
 class AllocationFailed(Exception):
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index f7006cc..a8927aa 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -6,7 +6,7 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import psutil
@@ -29,7 +29,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
-from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__file__)
@@ -76,6 +76,7 @@ class Server:
         mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: Optional[bool] = None,
+        tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
         **kwargs,
     ):
@@ -128,6 +129,8 @@ class Server:
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         device = torch.device(device)
+        if device.type == "cuda" and device.index is None:
+            device = torch.device(device.type, index=0)
         self.device = device
 
         if isinstance(torch_dtype, str):
@@ -141,6 +144,13 @@ class Server:
             logger.info("Model weights will be loaded in 8-bit format")
         self.load_in_8bit = load_in_8bit
 
+        if tensor_parallel_devices is None:
+            tensor_parallel_devices = (device,)
+        self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
+        if len(self.tensor_parallel_devices) > 1:
+            logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
+            check_device_balance(self.tensor_parallel_devices)
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
@@ -174,6 +184,7 @@ class Server:
                 device,
                 torch_dtype,
                 load_in_8bit=load_in_8bit,
+                tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
@@ -214,13 +225,28 @@ class Server:
             self.converted_model_name_or_path == "bigscience/bloom-petals"
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
         assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
+
+        if num_devices > 1:
+            memory_per_device = tuple(
+                torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
+            )
+            total_memory = min(memory_per_device) * num_devices
+            if max(memory_per_device) / min(memory_per_device) > 1.5:
+                raise ValueError(
+                    "GPU devices have highly uneven memory, which makes tensor parallelism inefficient. "
+                    "Please launch individual servers on each GPU or set --num_blocks manually to "
+                    "override this exception."
+                )
+        else:
+            total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        total_memory = torch.cuda.get_device_properties(self.device).total_memory
         block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
         gib = 1024**3
-        attn_cache_per_block = 0.5 * gib  # TODO: This does not account for manually set --attn_cache_size
+        attn_cache_per_block = 0.5 * gib * num_devices  # TODO: This does not account for manually set --attn_cache_size
 
-        num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
+        autograd_memory = 2 * gib * num_devices  # gpu memory used for intermediate tensors in rpc_backward
+        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
         logger.info(
@@ -260,6 +286,7 @@ class Server:
             sender_threads=self.sender_threads,
             use_auth_token=self.use_auth_token,
             load_in_8bit=self.load_in_8bit,
+            tensor_parallel_devices=self.tensor_parallel_devices,
             start=True,
         )
         try:
@@ -352,6 +379,7 @@ class ModuleContainer(threading.Thread):
         expiration: Optional[float],
         use_auth_token: Optional[str],
         load_in_8bit: bool,
+        tensor_parallel_devices: Sequence[torch.device],
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -367,7 +395,9 @@ class ModuleContainer(threading.Thread):
         joining_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")
 
-        memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+        assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
+
+        memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
@@ -380,18 +410,13 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                    max_disk_space=max_disk_space,
                 )
+                block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
 
-                if load_in_8bit:
-                    block = replace_8bit_linear(block)
-
-                block = block.to(device)
-                for param in block.parameters():
-                    param.requires_grad = False
-
-                backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
+                backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
+                    config=block_config,
                     memory_cache=memory_cache,
                     backend_dtype=backend_dtype,
                     args_schema=(
@@ -451,6 +476,7 @@ class ModuleContainer(threading.Thread):
         request_timeout: float,
         session_timeout: float,
         step_timeout: float,
+        device: Union[str, torch.device],
         start: bool,
         **kwargs,
     ):
@@ -469,7 +495,8 @@ class ModuleContainer(threading.Thread):
             )
             for _ in range(num_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, **kwargs)
+        self.runtime = Runtime(self.module_backends, device=None, **kwargs)
+        # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
             dht,
diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py
index 1374f94..330679c 100644
--- a/src/petals/server/task_pool.py
+++ b/src/petals/server/task_pool.py
@@ -5,7 +5,7 @@ import time
 from concurrent.futures._base import PENDING
 from dataclasses import dataclass, field
 from queue import PriorityQueue
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import torch
 from hivemind import get_logger
@@ -43,6 +43,7 @@ class PrioritizedTaskPool(TaskPoolBase):
 
     :param name: pool name, used for logging
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param device: if specified, input tensors will be moved to that device by default
     :param start: if True, start automatically at the end of __init__
     """
 
@@ -52,11 +53,13 @@ class PrioritizedTaskPool(TaskPoolBase):
         max_batch_size: int,
         name: str,
         min_batch_size=1,
+        device: Optional[torch.device] = None,
         daemon=True,
         start=False,
     ):
         super().__init__(process_func, daemon=daemon, name=name)
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.device = device
 
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
@@ -101,7 +104,7 @@ class PrioritizedTaskPool(TaskPoolBase):
             logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
             self.terminate()
 
-    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
+    def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         future = MPFuture()
         # Remove shmem from MPFuture. This disables the .cancel() feature but
@@ -129,10 +132,9 @@ class PrioritizedTaskPool(TaskPoolBase):
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
     ) -> Tuple[Any, List[torch.Tensor]]:
         """receive next batch of arrays"""
+        device = device if device is not None else self.device
         task = self._ordered_tasks.get(block=True, timeout=timeout)
-        batch_inputs = [
-            tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
-        ]
+        batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
         self._dispatched_tasks[task.uid] = task
         self.batch_receiver.recv()  # reduce the number of active batches
         if not self._ordered_tasks.empty():
@@ -142,11 +144,7 @@ class PrioritizedTaskPool(TaskPoolBase):
 
     def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
-        batch_outputs = [
-            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
-            for tensor in batch_outputs
-        ]
-
+        batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]
         task = self._dispatched_tasks.pop(uid, None)
         if task is None:
             logger.error(
@@ -182,3 +180,13 @@ class PrioritizedTaskPool(TaskPoolBase):
         assert len(item) == 2
         self._priority.value = float(item[0])
         self._oldest_undispatched_timestamp.value = float(item[1])
+
+
+def _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str], share_memory: bool = False):
+    if isinstance(arg, torch.Tensor):
+        arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad)
+        # note: it is important that non_blocking is disabled if share_memory=True; using share_memory on a tensor
+        # produced by a non-blocking copy will result in undefined behavior (depending on your gpu speed)
+        if share_memory:
+            arg = arg.share_memory_()
+    return arg
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index c24e710..73ad973 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -2,9 +2,10 @@ import fcntl
 import json
 import os
 import time
+from collections import Counter
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Sequence, Union
 
 import torch
 from hivemind.utils.logging import get_logger
@@ -12,7 +13,7 @@ from transformers import BloomConfig
 
 from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.convert_block import convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__file__)
@@ -37,6 +38,7 @@ def get_host_throughput(
     dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
 ) -> float:
@@ -57,6 +59,9 @@ def get_host_throughput(
     cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
     cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
     cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
+    if len(tensor_parallel_devices) > 1:
+        for i, device_i in enumerate(tensor_parallel_devices):
+            cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"
 
     cache = {}
     try:
@@ -69,7 +74,9 @@ def get_host_throughput(
             cache = {}
 
     if cache_key not in cache:
-        cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
+        cache[cache_key] = measure_throughput_info(
+            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+        )
 
         try:
             os.makedirs(cache_path.parent, exist_ok=True)
@@ -87,6 +94,7 @@ def measure_throughput_info(
     dtype: torch.dtype,
     *,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
 ) -> float:
     """Measure network and compute throughput in forward pass tokens per second"""
 
@@ -95,7 +103,9 @@ def measure_throughput_info(
     )
     return min(
         measure_network_rps(config),
-        measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit),
+        measure_compute_rps(
+            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+        ),
     )
 
 
@@ -129,14 +139,15 @@ def measure_compute_rps(
     dtype: torch.dtype,
     *,
     load_in_8bit: bool,
+    tensor_parallel_devices: Sequence[torch.device],
     n_tokens: int = 16,
     n_steps: int = 500,
 ) -> float:
+    if not tensor_parallel_devices:
+        tensor_parallel_devices = (device,)
     with torch.inference_mode():
         block = WrappedBloomBlock(config).to(dtype)
-        if load_in_8bit:
-            block = replace_8bit_linear(block)
-        block = block.to(device)
+        block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
 
         cache = None
         elapsed = 0
@@ -149,9 +160,13 @@ def measure_compute_rps(
         elapsed += time.perf_counter() - start_time
     device_rps = n_steps * n_tokens / elapsed
 
+    devices_repr = get_device_name(device)
+    if len(tensor_parallel_devices) > 1:
+        device_names = tuple(map(get_device_name, map(torch.device, tensor_parallel_devices)))
+        devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
+
     logger.info(
-        f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
-        f"{device_rps:.1f} RPS"
+        f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): {device_rps:.1f} RPS"
     )
     return device_rps
diff --git a/src/petals/utils/convert_8bit.py b/src/petals/utils/convert_8bit.py
deleted file mode 100644
index eeb29e7..0000000
--- a/src/petals/utils/convert_8bit.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import bitsandbytes as bnb
-import torch
-
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
-
-
-def replace_8bit_linear(model, threshold=6.0):
-    """
-    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
-    be kept as a `torch.nn.Linear` module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        threshold (`float`, *optional*):
-            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
-            `6.0` as described by the paper.
-    """
-    for n, module in model.named_children():
-        if len(list(module.children())) > 0:
-            replace_8bit_linear(module, threshold)
-
-        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
-            model._modules[n] = CustomLinear8bitLt(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                has_fp16_weights=False,
-                threshold=threshold,
-            )
-            model._modules[n].weight = bnb.nn.Int8Params(
-                module.weight.data, requires_grad=False, has_fp16_weights=False
-            ).to(module.weight.dtype)
-            model._modules[n].bias = module.bias
-    return model
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
new file mode 100644
index 0000000..794ecd9
--- /dev/null
+++ b/src/petals/utils/convert_block.py
@@ -0,0 +1,132 @@
+"""
+Tools for converting transformer blocks, applying quantization and/or tensor parallelism
+"""
+import re
+from typing import Sequence
+
+import bitsandbytes as bnb
+import tensor_parallel as tp
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from tensor_parallel.slicing_configs import get_bloom_config
+from transformers import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
+
+from petals.bloom.block import WrappedBloomBlock
+from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+def convert_block(
+    block: WrappedBloomBlock,
+    config: BloomConfig,
+    tensor_parallel_devices: Sequence[torch.device],
+    output_device: torch.device,
+    load_in_8bit: bool,
+    threshold: float = 6.0,
+    freeze: bool = True,
+) -> tp.TensorParallel:
+    """
+    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.int8() quantization
+
+    :note: some optimizations will modify the input block in-place!
+    :param block: a single transformer block, either pre-trained or newly initialized
+    :param config: HF transformers config for the full model
+    :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
+    :note: if there is only a single device, the model will still be wrapped with TensorParallel (for uniformity)
+    :param output_device: if tensor parallelism is used, collect the block outputs on this device
+    :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
+    :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
+    :param freeze: if True (default), make all module parameters non-trainable
+    :return: a module that acts like the original block, but runs with all specified optimizations
+
+    """
+    if freeze:
+        for param in block.parameters():
+            param.requires_grad = False
+
+    block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
+
+    if load_in_8bit:
+        block = replace_8bit_linear(block, threshold=threshold)
+
+    for shard, device in zip(block.module_shards, block.devices):
+        shard.to(device)
+
+    return block
+
+
+def replace_8bit_linear(model: nn.Module, threshold=6.0):
+    """
+    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
+    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
+    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
+    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
+    bitsandbytes-cudaXXX` where `XXX` is your CUDA version (e.g., 11.6 = 116)
+    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
+    be kept as a `torch.nn.Linear` module.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
+            `6.0` as described by the paper.
+    """
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_8bit_linear(module, threshold)
+
+        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
+            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
+            model._modules[n] = CustomLinear8bitLt(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                has_fp16_weights=False,
+                threshold=threshold,
+            )
+            model._modules[n].weight = bnb.nn.Int8Params(
+                module.weight.data, requires_grad=False, has_fp16_weights=False
+            ).to(module.weight.dtype)
+            model._modules[n].bias = module.bias
+    return model
+
+
+def make_tensor_parallel(
+    block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
+):
+    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
+    tp_config = get_bloom_config(model_config, devices)
+    del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
+    total_heads = 0
+    for tp_shard in tp_block.module_shards:
+        for submodule in tp_shard.modules():
+            if isinstance(submodule, BloomAttention):
+                total_heads += submodule.num_heads
+    assert total_heads == model_config.n_head
+    return tp_block
+
+
+def check_device_balance(devices: Sequence[torch.device]):
+    if not all(device.type == "cuda" for device in devices):
+        logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
+        return
+    unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
+    if len(unique_device_capabilities) > 1:
+        logger.warning(
+            f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
+            f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
+        )
+
+    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
+    used_memory = min(memory_per_device) * len(memory_per_device)
+    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
+    if wasted_memory_rate > 0.05:
+        logger.warning(
+            f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
+            f"Consider running high-memory GPUs in a separate server."
+        )
diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py
index 46c4bfe..554127f 100644
--- a/tests/test_aux_functions.py
+++ b/tests/test_aux_functions.py
@@ -7,10 +7,17 @@ from petals.server.throughput import measure_compute_rps, measure_network_rps
 
 
 @pytest.mark.forked
-def test_throughput_basic():
+@pytest.mark.parametrize("tensor_parallel", [False, True])
+def test_throughput_basic(tensor_parallel: bool):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
-        config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
+        config,
+        device=torch.device("cpu"),
+        dtype=torch.bfloat16,
+        load_in_8bit=False,
+        tensor_parallel_devices=tensor_parallel_devices,
+        n_steps=10,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
     network_rps = measure_network_rps(config)
diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py
index ab41ce8..664f255 100644
--- a/tests/test_block_exact_match.py
+++ b/tests/test_block_exact_match.py
@@ -13,7 +13,7 @@ from petals.dht_utils import get_remote_module
 
 
 @pytest.mark.forked
-def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
 
diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py
index ed76696..a8e585f 100644
--- a/tests/test_remote_sequential.py
+++ b/tests/test_remote_sequential.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
-from hivemind import DHT, BatchTensorDescriptor, get_logger
+import torch.nn.functional as F
+from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
 from hivemind.proto import runtime_pb2
 from test_utils import *
 
@@ -39,10 +40,10 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs)
+    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
 
     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad)
+    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
 
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
@@ -58,7 +59,7 @@ def test_remote_sequential():
     assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
     assert abs(approx_outputs - full_outputs).mean() < 0.01
     absmax = abs(full_grad).max()
-    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01
+    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
 
 
 class DummyCustomSequenceManager(RemoteSequenceManager):
@@ -87,9 +88,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     remote_sequential = RemoteSequential(config, dht)
 
-    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
-    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
-    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
+    output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
+    input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
     intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
 
     input_prompts = input_prompts.detach().requires_grad_(True)
@@ -117,10 +118,10 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
         block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
         (outputs_ref,) = block(outputs_ref)
 
-    assert torch.allclose(outputs_ref, outputs)
+    assert torch.allclose(outputs_ref, outputs, atol=1e-3)
 
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
     assert intermediate_prompts_ref.grad is not None
-    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)
diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
new file mode 100644
index 0000000..40eb1ee
--- /dev/null
+++ b/tests/test_tensor_parallel.py
@@ -0,0 +1,46 @@
+import random
+
+import pytest
+import torch
+import transformers
+from tensor_parallel import TensorParallel
+from tensor_parallel.slicing_configs import get_bloom_config
+from test_utils import MODEL_NAME
+
+from petals.bloom.from_pretrained import load_pretrained_block
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("custom_config", [True, False])
+@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
+def test_tp_block(devices, custom_config):
+    block_index = random.randint(0, 10)
+    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
+
+    tp_config = None
+    if custom_config:
+        tp_config = get_bloom_config(model_config, devices)
+
+    batch_size = 2
+    prefix_length = 5
+
+    test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
+    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
+    test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
+    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
+    grad_proj = torch.rand_like(test_inputs1)
+
+    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
+    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
+    y_ref.backward(grad_proj)
+
+    block_tp = TensorParallel(block, devices, config=tp_config)
+    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
+    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
+    y_ours.backward(grad_proj)
+
+    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
+    assert torch.allclose(y_ours, y_ref, atol=1e-6)
+    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
+    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)
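As a rough, standalone sketch of the new per-device cache accounting: MemoryCache.get_allocation_size charges the size limit with the largest per-device total rather than the sum over devices, since each device holds only its own shard of the attention cache. The shapes and devices below are made up; in the server, the descriptors come from TransformerBackend.get_inference_cache_descriptors.

    import torch
    from hivemind import TensorDescriptor

    from petals.server.memory_cache import MemoryCache

    # hypothetical cache layout: batch=1, 16 heads per shard, head_dim=128, max_length=2048, two shards
    descriptors = []
    for device in (torch.device("cuda:0"), torch.device("cuda:1")):
        keys = TensorDescriptor((1, 16, 128, 2048), dtype=torch.bfloat16, device=device)
        values = TensorDescriptor((1, 16, 2048, 128), dtype=torch.bfloat16, device=device)
        descriptors.extend((keys, values))

    # only descriptor metadata is inspected here (no tensors are allocated), so this runs without GPUs
    print(MemoryCache.get_allocation_size(*descriptors) / 1024**2, "MiB per device")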