From 158621677bac37572c2cf256c419472d507d451c Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 6 Sep 2023 19:43:30 +0400 Subject: [PATCH 01/11] Bump version to 2.2.0 (#502) --- README.md | 2 +- src/petals/__init__.py | 2 +- src/petals/models/falcon/config.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6987489..1f410ef 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@

-Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: +Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Falcon**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: ```python from transformers import AutoTokenizer diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 4e4a9d0..f513f65 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -17,7 +17,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.1.0" +__version__ = "2.2.0" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/models/falcon/config.py b/src/petals/models/falcon/config.py index a1ae5e9..9fadede 100644 --- a/src/petals/models/falcon/config.py +++ b/src/petals/models/falcon/config.py @@ -31,6 +31,9 @@ class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): + if "180B" in model_name_or_path.upper(): + logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license") + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: dht_prefix = str(model_name_or_path) From 5ce4f1a1598b1fca9fe6bd30cfbd85aa99bce2c7 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 15 Sep 2023 23:53:57 +0400 Subject: [PATCH 02/11] Store (start_block, end_block) in each DHT record for reliability (#510) This PR fixes gaps in the DHT server info caused by unavailable DHT keys. Now, one DHT key is enough to get info about all blocks hosted by a server - so we'll see info until all keys are unavailable. Also, this PR refactors `petals.client.routing` and `petals.server.block_selection` modules to use the common `compute_spans()` function (defined in `petals.utils.dht`) and `RemoteSpanInfo` class (defined in `petals.data_structures`). --- src/petals/client/routing/sequence_info.py | 65 ++++------------- src/petals/client/routing/sequence_manager.py | 4 -- src/petals/data_structures.py | 35 ++++++--- src/petals/server/block_selection.py | 72 +++++++------------ src/petals/server/server.py | 27 +++---- src/petals/utils/dht.py | 53 ++++++++++---- 6 files changed, 119 insertions(+), 137 deletions(-) diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py index bce6712..2c9137b 100644 --- a/src/petals/client/routing/sequence_info.py +++ b/src/petals/client/routing/sequence_info.py @@ -1,17 +1,15 @@ import dataclasses import time -from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar +from typing import Iterable, List, Optional, Tuple from hivemind import get_logger from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState +from petals.utils.dht import compute_spans logger = get_logger(__name__) -T = TypeVar("T") - - @dataclasses.dataclass class RemoteSequenceInfo: """ @@ -30,7 +28,7 @@ class RemoteSequenceInfo: last_updated_time: Optional[float] @classmethod - def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T: + def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo": block_uids = tuple(block_uids) empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids) empty_spans = tuple([] for _ in range(len(block_uids))) @@ -39,7 +37,7 @@ class RemoteSequenceInfo: def __getitem__(self, ix: slice): assert isinstance(ix, slice) block_uids, block_infos = self.block_uids[ix], self.block_infos[ix] - spans_by_priority, spans_containing_block = self.compute_spans(block_infos) + spans_by_priority, spans_containing_block = self._sort_spans(block_infos) return RemoteSequenceInfo( block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time ) @@ -47,60 +45,23 @@ class RemoteSequenceInfo: def __len__(self): return len(self.block_uids) - def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]): + def update_(self, new_block_infos: List[RemoteModuleInfo]): assert len(new_block_infos) == len(self.block_uids) for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)): - if info is None: - logger.debug(f"Found no block info for block {uid}") - continue - if not isinstance(info, RemoteModuleInfo): - logger.warning(f"Unexpected dht entry type for {uid}: {info}") - continue - if not info.servers: - logger.debug(f"Found no active peers for block {uid}") - continue - if info.uid != uid: - logger.warning(f"The DHT entry for {uid} actually points to {info.uid}") - continue + assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}" self.block_infos[block_index].servers = info.servers - self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos) + self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos) self.last_updated_time = time.perf_counter() @staticmethod - def compute_spans(block_infos: Sequence[RemoteModuleInfo]): - closed_spans = [] - active_spans = {} - for block_index, info in enumerate(block_infos): - if info is not None: - for peer_id, server_info in info.servers.items(): - if server_info.state != ServerState.ONLINE: - continue - if peer_id not in active_spans: - active_spans[peer_id] = RemoteSpanInfo( - peer_id=peer_id, - start=block_index, - end=block_index + 1, - server_info=server_info, - ) - else: # peer_id in active_spans - active_spans[peer_id].end = block_index + 1 - - for peer_id in list(active_spans.keys()): - if ( - info is None - or peer_id not in info.servers - or info.servers[peer_id].state != ServerState.ONLINE - or block_index == len(block_infos) - 1 - ): - closed_spans.append(active_spans.pop(peer_id)) - assert not active_spans, f"spans: {active_spans}" - - closed_spans.sort(key=lambda span: span.length, reverse=True) + def _sort_spans(block_infos: List[RemoteModuleInfo]): + spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values()) + spans_by_priority.sort(key=lambda span: span.length, reverse=True) - spans_containing_block = tuple(list() for _ in range(len(block_infos))) - for span in closed_spans: + spans_containing_block = tuple([] for _ in range(len(block_infos))) + for span in spans_by_priority: for block_index in range(span.start, span.end): spans_containing_block[block_index].append(span) - return closed_spans, spans_containing_block + return spans_by_priority, spans_containing_block diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 3e239b4..ed5224c 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -117,7 +117,6 @@ class RemoteSequenceManager: if state.sequence_info.last_updated_time is not None: assert block_uids == state.sequence_info.block_uids self._thread.ready.set() # no need to await the first dht fetch - self._need_latest_infos = True @staticmethod def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]: @@ -346,9 +345,6 @@ class RemoteSequenceManager: ) for block_info in new_block_infos: - if not block_info: - continue - # Apply allow and block lists block_info.servers = { peer_id: server_info diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 7c86f14..9cbbf76 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -11,18 +11,15 @@ UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4" -class ServerState(Enum): - OFFLINE = 0 - JOINING = 1 - ONLINE = 2 - - -RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) +def parse_uid(uid: ModuleUID) -> Tuple[str, int]: + assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs" + dht_prefix, index = uid.split(UID_DELIMITER) + return dht_prefix, int(index) @pydantic.dataclasses.dataclass class ModelInfo: - num_blocks: int + num_blocks: pydantic.conint(ge=1, strict=True) repository: Optional[str] = None def to_dict(self) -> dict: @@ -33,11 +30,23 @@ class ModelInfo: return cls(**source) +class ServerState(Enum): + OFFLINE = 0 + JOINING = 1 + ONLINE = 2 + + +RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) + + @pydantic.dataclasses.dataclass class ServerInfo: state: ServerState throughput: RPS + start_block: Optional[pydantic.conint(ge=0, strict=True)] = None + end_block: Optional[pydantic.conint(ge=0, strict=True)] = None + public_name: Optional[str] = None version: Optional[str] = None @@ -83,9 +92,17 @@ class RemoteSpanInfo: server_info: ServerInfo @property - def length(self): + def length(self) -> int: return self.end - self.start + @property + def state(self) -> ServerState: + return self.server_info.state + + @property + def throughput(self) -> float: + return self.server_info.throughput + RPCInfo = Dict[str, Any] diff --git a/src/petals/server/block_selection.py b/src/petals/server/block_selection.py index cc050d4..441c0cd 100644 --- a/src/petals/server/block_selection.py +++ b/src/petals/server/block_selection.py @@ -1,54 +1,23 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List import numpy as np from hivemind import PeerID, get_logger -from petals.data_structures import RemoteModuleInfo, ServerState - -__all__ = ["choose_best_blocks", "should_choose_other_blocks"] +from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState +from petals.utils.dht import compute_spans logger = get_logger(__name__) -@dataclass -class Span: - start: int - end: int - throughput: float - state: ServerState - - @property - def length(self): - return self.end - self.start - - def move_to(self, new_start: int) -> None: - self.start, self.end = new_start, new_start + self.length - - -def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]: - spans = {} - throughputs = np.zeros(len(module_infos)) - for block, module in enumerate(module_infos): - if module is None: - continue - - # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers. - # If the order were not defined, we would get slightly different values due to floating point errors, - # which may cause excess block replacements. - for peer_id, server in sorted(module.servers.items()): - if server.state == ServerState.OFFLINE: - continue +def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray: + # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers. + # If the order were not defined, we would get slightly different values due to floating point errors, + # which may cause excess block replacements. - if peer_id in spans: - spans[peer_id].start = min(spans[peer_id].start, block) - spans[peer_id].end = max(spans[peer_id].start, block + 1) - else: - spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state) - - throughputs[block] += server.throughput - - return spans, throughputs + throughputs = np.zeros(total_blocks) + for span in sorted(spans.values(), key=lambda span: span.peer_id): + throughputs[span.start : span.end] += span.throughput + return throughputs def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int: @@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int: return min(options)[-1] -def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]: - _, throughputs = compute_spans(module_infos) +def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]: + spans = compute_spans(module_infos, min_state=ServerState.JOINING) + throughputs = compute_throughputs(spans, total_blocks=len(module_infos)) + start = _choose_best_start(throughputs, num_blocks) return list(range(start, start + num_blocks)) +def _move_span(span: RemoteSpanInfo, new_start: int): + span.start, span.end = new_start, new_start + span.length + + def should_choose_other_blocks( - local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float + local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float ) -> bool: if balance_quality > 1.0: return True # Forces rebalancing on each check (may be used for debugging purposes) - spans, throughputs = compute_spans(module_infos) + spans = compute_spans(module_infos, min_state=ServerState.JOINING) + throughputs = compute_throughputs(spans, total_blocks=len(module_infos)) initial_throughput = throughputs.min() eps = 1e-3 @@ -88,7 +64,7 @@ def should_choose_other_blocks( return False # This server is on its best place already throughputs[local_span.start : local_span.end] += local_span.throughput * eps - local_span.move_to(new_start) + _move_span(local_span, new_start) throughputs[local_span.start : local_span.end] += local_span.throughput moved = True @@ -105,7 +81,7 @@ def should_choose_other_blocks( throughputs[span.start : span.end] += span.throughput * eps if span.start != new_start: - span.move_to(new_start) + _move_span(span, new_start) moved = True throughputs[span.start : span.end] += span.throughput diff --git a/src/petals/server/server.py b/src/petals/server/server.py index fd9f766..82388aa 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -23,7 +23,7 @@ from transformers import PretrainedConfig import petals from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size, resolve_block_dtype @@ -220,11 +220,10 @@ class Server: num_blocks = min(num_blocks, self.block_config.num_hidden_layers) if block_indices is not None: try: - first_block_index, last_block_index = block_indices.split(":") - first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index))) + start_block, end_block = [int(index.strip()) for index in block_indices.split(":")] except Exception as e: raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)") - block_indices = range(first_block_index, last_block_index) + block_indices = range(start_block, end_block) num_blocks = len(block_indices) self.strict_block_indices, self.num_blocks = block_indices, num_blocks @@ -703,11 +702,16 @@ class ModuleAnnouncerThread(threading.Thread): self.expiration = expiration self.trigger = threading.Event() + self.dht_prefix = parse_uid(module_uids[0])[0] + block_indices = [parse_uid(uid)[1] for uid in module_uids] + self.server_info.start_block = min(block_indices) + self.server_info.end_block = max(block_indices) + 1 + self.max_pinged = max_pinged - self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0] - block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids] - start_block, end_block = min(block_indices), max(block_indices) + 1 - self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)] + self.next_uids = [ + f"{self.dht_prefix}{UID_DELIMITER}{i}" + for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1) + ] self.ping_aggregator = PingAggregator(self.dht) def run(self) -> None: @@ -755,12 +759,11 @@ class ModuleAnnouncerThread(threading.Thread): def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]: module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True) - middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers} + middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers} pinged_servers = set(sample_up_to(middle_servers, self.max_pinged)) pinged_servers.discard(self.dht.peer_id) - if module_infos[-1] is not None: - # Sample servers hosting the block after the last one (most likely continuations) separately - pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) + # Sample servers hosting the block after the last one (most likely continuations) separately + pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) self.ping_aggregator.ping(list(pinged_servers)) diff --git a/src/petals/utils/dht.py b/src/petals/utils/dht.py index 0710f60..4faf74a 100644 --- a/src/petals/utils/dht.py +++ b/src/petals/utils/dht.py @@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.p2p import PeerID from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo +from petals.data_structures import ( + CHAIN_DELIMITER, + UID_DELIMITER, + ModuleUID, + RemoteModuleInfo, + RemoteSpanInfo, + ServerInfo, + ServerState, + parse_uid, +) logger = get_logger(__name__) @@ -70,7 +79,7 @@ def get_remote_module_infos( *, latest: bool = False, return_future: bool = False, -) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]: +) -> Union[List[RemoteModuleInfo], MPFuture]: return dht.run_coroutine( partial( _get_remote_module_infos, @@ -90,7 +99,7 @@ async def _get_remote_module_infos( active_adapter: Optional[str], expiration_time: Optional[DHTExpiration], latest: bool, -) -> List[Optional[RemoteModuleInfo]]: +) -> List[RemoteModuleInfo]: if latest: assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both" expiration_time = math.inf @@ -99,14 +108,14 @@ async def _get_remote_module_infos( num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers) - modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids) - for i, uid in enumerate(uids): - metadata = found[uid] + modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids] + for module_info in modules: + metadata = found[module_info.uid] if metadata is None or not isinstance(metadata.value, dict): if metadata is not None: - logger.warning(f"Incorrect metadata for {uid}: {metadata}") + logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}") continue - servers = {} + for peer_id, server_info in metadata.value.items(): try: peer_id = PeerID.from_base58(peer_id) @@ -116,9 +125,29 @@ async def _get_remote_module_infos( logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}") continue - servers[peer_id] = server_info + module_info.servers[peer_id] = server_info except (TypeError, ValueError) as e: - logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}") - if servers: - modules[i] = RemoteModuleInfo(uid, servers) + logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}") return modules + + +def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]: + block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0 + num_blocks = len(module_infos) + + spans = {} + for block_idx, module_info in enumerate(module_infos): + for peer_id, server_info in sorted(module_info.servers.items()): + if server_info.state.value < min_state.value: + continue + + if peer_id not in spans or spans[peer_id].state.value < server_info.state.value: + spans[peer_id] = RemoteSpanInfo( + peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info + ) + if server_info.start_block is not None and server_info.end_block is not None: + spans[peer_id].start = max(server_info.start_block - block_offset, 0) + spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks) + elif spans[peer_id].state == server_info.state: + spans[peer_id].end = max(spans[peer_id].end, block_idx + 1) + return spans From a2484b305374d275763e909fbda36a2b79338e30 Mon Sep 17 00:00:00 2001 From: FYY Date: Tue, 19 Sep 2023 20:01:23 -0400 Subject: [PATCH 03/11] Fix file locks in NFS-mounted directories (#517) Fix #515. --- src/petals/server/throughput.py | 2 +- src/petals/utils/disk_cache.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index bf71f44..c42bdb9 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -56,7 +56,7 @@ def get_server_throughput( # We use the system-wide lock since only one process at a time can measure the host throughput os.makedirs(lock_path.parent, exist_ok=True) - with open(lock_path, "wb") as lock_fd: + with open(lock_path, "wb+") as lock_fd: logger.info("Loading throughput info") fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX) # The OS will release the lock when lock_fd is closed or the process is killed diff --git a/src/petals/utils/disk_cache.py b/src/petals/utils/disk_cache.py index a26a0f5..5de47c8 100644 --- a/src/petals/utils/disk_cache.py +++ b/src/petals/utils/disk_cache.py @@ -22,7 +22,7 @@ def _blocks_lock(cache_dir: Optional[str], mode: int): lock_path = Path(cache_dir, BLOCKS_LOCK_FILE) os.makedirs(lock_path.parent, exist_ok=True) - with open(lock_path, "wb") as lock_fd: + with open(lock_path, "wb+") as lock_fd: fcntl.flock(lock_fd.fileno(), mode) # The OS will release the lock when lock_fd is closed or the process is killed yield From 1d9401ddceca53fd2a3d21e48c656242c65c6692 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 22 Sep 2023 06:16:32 +0400 Subject: [PATCH 04/11] Update README.md (#520) --- README.md | 109 ++++++------------------------------------------------ 1 file changed, 12 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 1f410ef..63449ae 100644 --- a/README.md +++ b/README.md @@ -8,14 +8,14 @@

-Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Falcon**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: +Generate text with distributed **Llama 2** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: ```python from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM # Choose any model available at https://health.petals.dev -model_name = "petals-team/StableBeluga2" +model_name = "petals-team/StableBeluga2" # This one is fine-tuned Llama 2 (70B) # Connect to a distributed network hosting model layers tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -31,9 +31,9 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat... 🚀  Try now in Colab

-🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev). +🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. -🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. +🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev). 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)! @@ -81,9 +81,8 @@ python3 -m petals.cli.run_server petals-team/StableBeluga2 ## How does it work? -- Petals runs large language models like [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning. -- Single-batch inference runs at **up to 6 steps/sec** for **Llama 2** (70B) and ≈ 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec. -- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch. +- You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Single‑batch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps. +- You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.

@@ -113,99 +112,15 @@ Advanced guides: - Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) -## Benchmarks - -The benchmarks below are for BLOOM-176B: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
NetworkSingle-batch inference
(steps/s)
Parallel forward
(tokens/s)
BandwidthRound-trip
latency
Sequence lengthBatch size
1282048164
Offloading, max. possible speed on 1x A100 1
256 Gbit/s0.180.182.7170.3
128 Gbit/s0.090.092.4152.8
Petals on 14 heterogeneous servers across Europe and North America 2
Real world0.830.7932.6179.4
Petals on 3 servers, with one A100 each 3
1 Gbit/s< 5 ms1.711.5470.0253.6
100 Mbit/s< 5 ms1.661.4956.4182.0
100 Mbit/s100 ms1.231.1119.7112.2
- -1 **An upper bound for offloading performance.** We base our offloading numbers on the best possible hardware setup for offloading: CPU RAM offloading via PCIe 4.0 with 16 PCIe lanes per GPU and PCIe switches for pairs of GPUs. We assume zero latency for the upper bound estimation. In 8-bit, the model uses 1 GB of memory per billion parameters. PCIe 4.0 with 16 lanes has a throughput of 256 Gbit/s, so offloading 176B parameters takes 5.5 seconds. The throughput is twice as slow (128 Gbit/s) if we have two GPUs behind the same PCIe switch. - -2 **A real-world distributed setting** with 14 servers holding 2× RTX 3060, 4× 2080Ti, 2× 3090, 2× A4000, and 4× A5000 GPUs. These are personal servers and servers from university labs, spread across Europe and North America and connected to the Internet at speeds of 100–1000 Mbit/s. 4 servers operate from under firewalls. - -3 **An optimistic setup** that requires least communication. The client nodes have 8 CPU cores and no GPU. - -We provide more evaluations and discuss these results in more detail in **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf). - -## 🛠️ Contributing +### Benchmarks + +Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf). + +### 🛠️ Contributing Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing. -## 📜 Citation +### 📜 Citation Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel. [Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188) From ae19b650959fadd837f4db799f9ec35011199506 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 8 Oct 2023 22:09:46 +0300 Subject: [PATCH 05/11] Add position_ids argument to DistributedFalconModel (#525) --- src/petals/models/falcon/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/petals/models/falcon/model.py b/src/petals/models/falcon/model.py index 32c0b6f..296214d 100644 --- a/src/petals/models/falcon/model.py +++ b/src/petals/models/falcon/model.py @@ -47,6 +47,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[RemotePastKeyValues] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -68,6 +69,9 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix assert ( attention_mask is None or (attention_mask == 1).all() ), f"Custom attention masks are not supported, {attention_mask=}" + assert ( + position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all() + ), f"Non-consecutive position_ids are not supported, {position_ids=}" assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" From 47d50e1e2938f8a0174caf670b25dea5345c6830 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 23 Oct 2023 05:26:40 +0600 Subject: [PATCH 06/11] Improve default arguments for clients and servers (#530) This PR updates multiple default arguments in clients and servers: 1. **The client defaults to `torch_dtype=torch.float32` instead of `torch_dtype="auto"`.** The old default was to load weights in the dtype they are saved in (usually bfloat16/float16), which caused issues when the client was run on CPU (the default unless you call `.cuda()`). Specifically, bfloat16 is slow on most CPUs (unless a CPU supports AVX512) and float16 can't be run natively and leads to an exception. This default was a legacy of the earliest Petals versions designed to run BLOOM - its embeddings were so big that they didn't fit into RAM in float32 (e.g., in Colab). The newer models don't have this issue. In contrast, the new default leads to good speed on all CPUs and is consistent with PyTorch and HF Transformers. Also, the client now shows "bfloat16 on non-AVX512 CPU" in all cases (previously this warning was shown only if the machine has enough RAM to fit float32 weights, which could hide the crucial reason of inference being slow). **Note:** This change is backward-incompatible, so we have to increase at least the minor package version (2.2.0 -> 2.3.0.dev0). 2. **The server uses 2x smaller `--attn_cache_tokens`.** The old default led to loading 39 (out of 80) or 78 (out of 80) blocks for popular models on some GPU types, which visibly slowed down inference due to an excess network hop. It was also leaving too much cache, so that inference slowed down much before the cache is used. The new default leads to more efficient block layouts and makes the inference routing algorithm choose alternative paths through other servers when a particular server already has enough active inference sessions (= its cache is full). 3. **The client's max number of retries can be limited by the `PETALS_MAX_RETRIES` env var.** This is to limit `ClientConfig.max_retries` in tests, so we see tracebacks instead of retrying indefinitely in case of errors. --- .github/workflows/run-tests.yaml | 3 +++ src/petals/__init__.py | 2 +- src/petals/cli/run_server.py | 6 +++--- src/petals/client/config.py | 6 +++++- src/petals/client/from_pretrained.py | 10 +--------- src/petals/client/lm_head.py | 12 +++++------- src/petals/server/server.py | 2 +- 7 files changed, 19 insertions(+), 22 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 05cebdd..b9dcc01 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -102,6 +102,9 @@ jobs: export no_proxy=* export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES + # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely + export PETALS_MAX_RETRIES=10 + pytest tests --durations=0 --durations-min=1.0 -v # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index f513f65..8671fc2 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -17,7 +17,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.2.0" +__version__ = "2.3.0.dev0" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 94f5c2e..5208438 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -70,17 +70,17 @@ def main(): parser.add_argument('--inference_max_length', type=int, default=None, help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. ' - 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others') parser.add_argument('--min_batch_size', type=int, default=1, help='Minimum required batch size for all operations (in total tokens)') parser.add_argument('--max_batch_size', type=int, default=None, help='The total number of tokens in the same batch will not exceed this value. ' - 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others') parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024, help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks') parser.add_argument('--attn_cache_tokens', type=int, default=None, help='The number of past attention key/value pairs that will be stored between inference steps. ' - 'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others') parser.add_argument('--cache_dir', type=str, default=None, help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.') diff --git a/src/petals/client/config.py b/src/petals/client/config.py index e255024..a2f8f42 100644 --- a/src/petals/client/config.py +++ b/src/petals/client/config.py @@ -1,10 +1,14 @@ import dataclasses +import os from typing import Optional, Sequence, Union from hivemind import PeerID from petals.constants import PUBLIC_INITIAL_PEERS +_max_retries = os.getenv("PETALS_MAX_RETRIES") +DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None + @dataclasses.dataclass class ClientConfig: @@ -21,7 +25,7 @@ class ClientConfig: request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests update_period: float = 60 # refresh DHT information once in this many seconds - max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) + max_retries: Optional[int] = DEFAULT_MAX_RETRIES # max number of retries before an exception (default: inf) min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) max_backoff: float = 60 # limit maximal sleep time between retries to this value ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds diff --git a/src/petals/client/from_pretrained.py b/src/petals/client/from_pretrained.py index f2c88d2..4b9d8e5 100644 --- a/src/petals/client/from_pretrained.py +++ b/src/petals/client/from_pretrained.py @@ -6,7 +6,6 @@ import tempfile from contextvars import ContextVar from typing import List, Optional, Tuple, Union -import torch from hivemind.utils.logging import get_logger from transformers import BloomPreTrainedModel, modeling_utils @@ -22,21 +21,14 @@ class FromPretrainedMixin: model_name_or_path: Union[str, os.PathLike, None], *args, low_cpu_mem_usage: Optional[bool] = None, - torch_dtype: Optional[Union[str, torch.dtype]] = None, **kwargs, ): model_name_or_path = get_compatible_model_repo(model_name_or_path) if low_cpu_mem_usage is None: low_cpu_mem_usage = True - if torch_dtype is None: - # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast, - # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights. - torch_dtype = "auto" with ignore_keys(cls._keys_to_ignore_on_load_unexpected): - return super().from_pretrained( - model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs - ) + return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( "low_cpu_mem_usage(`bool`, *optional*)", diff --git a/src/petals/client/lm_head.py b/src/petals/client/lm_head.py index cbea89d..bc0e293 100644 --- a/src/petals/client/lm_head.py +++ b/src/petals/client/lm_head.py @@ -1,8 +1,7 @@ import dataclasses import platform -from typing import Optional, Union +from typing import Union -import psutil import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -68,11 +67,10 @@ class LMHead(nn.Module): assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive" if not self._bf16_warning_shown: - if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: - logger.warning( - "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. " - "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)" - ) + logger.warning( + "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. " + "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)" + ) self._bf16_warning_shown = True hidden_states = hidden_states.float() diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 82388aa..45884e3 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -203,7 +203,7 @@ class Server: # For attention cache in GPU or RAM if attn_cache_tokens is None: - attn_cache_tokens = 32768 if is_multiquery_attn else 8192 + attn_cache_tokens = 16384 if is_multiquery_attn else 4096 cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens cache_values_per_block //= self.block_config.num_key_value_groups self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype) From 82a97d6e9ea18e79639f375f124c4c83fe1933e8 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 23 Oct 2023 20:13:13 +0400 Subject: [PATCH 07/11] Fix beam search in GPU clients (#531) Fixes #503. --- .github/workflows/run-tests.yaml | 22 ++++++--------- src/petals/client/inference_session.py | 38 ++++++++++---------------- 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index b9dcc01..74b731d 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -48,7 +48,6 @@ jobs: export MODEL_NAME="${{ matrix.model }}" export REF_NAME="${{ matrix.model }}" export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}" - export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}" # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) @@ -61,27 +60,25 @@ jobs: until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init - python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \ - --mean_balance_check_period 10 \ - --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log & + export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \ + --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS" + export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}" + + $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log & SERVER1_PID=$! # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there sleep 10 # wait for the 1st server to choose blocks - python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \ - --identity_path tests/server2.id \ - --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log & + $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log & SERVER2_PID=$! - python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \ - --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \ - --initial_peers $INITIAL_PEERS --throughput auto &> server3.log & + $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \ + --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log & SERVER3_PID=$! # ^-- chunking test - python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \ - --initial_peers $INITIAL_PEERS --throughput auto &> server4.log & + $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log & SERVER4_PID=$! # ^-- tensor parallelism test (not compatible with adapters yet) @@ -121,4 +118,3 @@ jobs: # [Step 4] Clean up kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID - echo "Done!" diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 28d3632..34d24c7 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -84,12 +84,7 @@ class _ServerInferenceSession: break # this message means "done sending" def step( - self, - inputs: torch.Tensor, - prompts: Optional[torch.Tensor] = None, - hypo_ids: Optional[torch.Tensor] = None, - *, - step_id: str, + self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs @@ -114,21 +109,6 @@ class _ServerInferenceSession: else: inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further - if prompts is None or is_dummy(prompts): - prompts = DUMMY - else: - assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" - assert prompts.shape[0] == self.num_blocks - assert prompts.shape[1] in (inputs.shape[0], 1) - assert prompts.shape[2] <= inputs.shape[1] - assert prompts.shape[3] == inputs.shape[2] - - if hypo_ids is None or is_dummy(hypo_ids): - hypo_ids = DUMMY_INT64 - else: - assert len(hypo_ids) == len(inputs) - assert hypo_ids.dtype == torch.int64 - # serialize inputs and put them into the queue input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids) @@ -275,7 +255,9 @@ class InferenceSession: assert not self._closed and not self._server_sessions return self - def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + def step( + self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.") @@ -285,11 +267,21 @@ class InferenceSession: else: assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" assert prompts.shape[0] == self.num_blocks + assert prompts.shape[1] in (inputs.shape[0], 1) + assert prompts.shape[2] <= inputs.shape[1] + assert prompts.shape[3] == inputs.shape[2] + + if hypo_ids is None or is_dummy(hypo_ids): + hypo_ids = DUMMY_INT64 + else: + assert len(hypo_ids) == len(inputs) + assert hypo_ids.dtype == torch.int64 inputs_device = inputs.device inputs_dtype = inputs.dtype inputs = inputs.cpu() prompts = prompts.cpu() + hypo_ids = hypo_ids.cpu() step_id = str(uuid.uuid4()) n_input_tokens = inputs.shape[1] @@ -310,7 +302,7 @@ class InferenceSession: server_session = self._server_sessions[server_idx] inputs = server_session.step( - inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs + inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id ) server_idx += 1 From dcce43670f79001f169a5d962d887ebb03530393 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 7 Nov 2023 18:19:19 +0300 Subject: [PATCH 08/11] Hotfix: set transformers version <=4.34 temporarily (#538) * fix transformers version for now Co-authored-by: horik --- setup.cfg | 2 +- src/petals/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index c8dbc9a..e2be3d1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = accelerate>=0.22.0 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 - transformers>=4.32.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py + transformers>=4.32.0,<4.35.0 # if you change this, please also change version assert in petals/__init__.py speedtest-cli==2.1.3 pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet hivemind==1.1.10.post2 diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 8671fc2..e17b4a1 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -22,8 +22,8 @@ __version__ = "2.3.0.dev0" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): assert ( - version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0") - ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0" + version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0") + ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0" def _override_bfloat16_mode_default(): From 25a0796b3946beebf6888f96050ac143a392bc9c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 7 Nov 2023 19:05:54 +0300 Subject: [PATCH 09/11] Hotfix: require peft version 0.5.0 (#539) Peft: strict version check for now Co-authored-by: horik --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index e2be3d1..ef35f84 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = cpufeature>=0.2.0; platform_machine == "x86_64" packaging>=20.9 sentencepiece>=0.1.99 - peft>=0.5.0 + peft==0.5.0 safetensors>=0.3.1 Dijkstar>=2.6.0 From 03cbe90234ccd4e3cf749d9370f53bea2a1dcb67 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 14 Nov 2023 18:14:19 +0100 Subject: [PATCH 10/11] Optimize LLaMA for inference (#513) * Optimize LLaMa for inference * Fix model type detection in tests --- src/petals/models/llama/block.py | 219 +++++++++++++++++++++++++++++-- src/petals/utils/cuda_graphs.py | 76 +++++++++++ tests/test_optimized_layers.py | 98 +++++++++++++- 3 files changed, 378 insertions(+), 15 deletions(-) create mode 100644 src/petals/utils/cuda_graphs.py diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py index 55f659a..a8d433d 100644 --- a/src/petals/models/llama/block.py +++ b/src/petals/models/llama/block.py @@ -3,13 +3,219 @@ LLaMA intermediate layer Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py See commit history for authorship. """ +import math from typing import Optional, Tuple import torch -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + repeat_kv, + rotate_half, +) +from petals.utils.cuda_graphs import make_inference_graphed_callable + + +def apply_rotary_pos_emb(q, k, cos, sin): + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class OptimizedLlamaAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._rotary_graph = None + + def _optimized_apply_rotary(self, query_states, key_states, cos, sin): + if self._rotary_graph is None: + self._rotary_graph = make_inference_graphed_callable( + apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin) + ) + return self._rotary_graph(query_states, key_states, cos, sin) -class WrappedLlamaBlock(LlamaDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert not output_attentions + assert position_ids is None + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[:, :, kv_seq_len - q_len :] + sin = sin[:, :, kv_seq_len - q_len :] + + if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class OptimizedLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.self_attn = OptimizedLlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.pre_attn_graph = None + self.post_attn_graph = None + + def _optimized_input_layernorm(self, hidden_states): + if self.pre_attn_graph is None: + self.pre_attn_graph = make_inference_graphed_callable( + self.input_layernorm.forward, sample_args=(hidden_states,) + ) + return self.pre_attn_graph(hidden_states) + + def _optimized_output_layernorm(self, hidden_states): + if self.post_attn_graph is None: + self.post_attn_graph = make_inference_graphed_callable( + self.post_attention_layernorm.forward, sample_args=(hidden_states,) + ) + return self.post_attn_graph(hidden_states) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + hidden_states = self._optimized_input_layernorm(hidden_states) + else: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + hidden_states = self._optimized_output_layernorm(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class WrappedLlamaBlock(OptimizedLlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, @@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer): seq_length_with_past = seq_length_with_past + past_key_values_length past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) - if position_ids is None: - device = hidden_states.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + assert position_ids is None # embed positions if attention_mask is None: diff --git a/src/petals/utils/cuda_graphs.py b/src/petals/utils/cuda_graphs.py new file mode 100644 index 0000000..216ecf1 --- /dev/null +++ b/src/petals/utils/cuda_graphs.py @@ -0,0 +1,76 @@ +import torch +from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten + + +def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3): + """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass""" + assert not isinstance(callable, torch.nn.Module) + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`." + ) + + flatten_arg, _ = _tree_flatten(sample_args) + flatten_sample_args = tuple(flatten_arg) + assert all( + isinstance(arg, torch.Tensor) for arg in flatten_arg + ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed." + + len_user_args = len(sample_args) + static_input_surface = flatten_sample_args + + graph = torch.cuda.CUDAGraph() + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(num_warmup_iters): + outputs, _ = _tree_flatten(callable(*sample_args)) + del outputs + torch.cuda.current_stream().wait_stream(s) + + # Capture forward graph + with torch.cuda.graph(graph): + outputs = callable(*sample_args) + + flatten_outputs, output_unflatten_spec = _tree_flatten(outputs) + static_outputs = tuple(flatten_outputs) + + def make_graphed_function( + graph, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + ): + def replay_graph(*inputs): + # At this stage, only the user args may (potentially) be new tensors. + for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + def functionalized(*user_args): + # Runs the autograd function with inputs == all inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. + flatten_user_args, _ = _tree_flatten(user_args) + out = replay_graph(*flatten_user_args) + return _tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callable + graphed = make_graphed_function( + graph, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + ) + return graphed diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 5baa1a2..84cbfff 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple import pytest import torch from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, convert_block @@ -94,10 +95,91 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): return state -@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models") +class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + batch_size, seq_length, _ = hidden_states.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + past_key_value = layer_past + if past_key_value is not None: + past_key_values_length = past_key_value[0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) + + if position_ids is None: + device = hidden_states.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = LlamaModel._prepare_decoder_attention_mask( + None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_llama_to_bloom( + present_key_value, batch_size, seq_length_with_past + ) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_llama( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + key_states = key_states.permute(0, 2, 1) + key_states = key_states.view( + batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) + value_states = value_states.view(*key_states.shape) + return (key_states, value_states) + + def _reorder_cache_from_llama_to_bloom( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + value_states = value_states.view( + batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) + key_states = key_states.view(*value_states.shape) + key_states = key_states.permute(0, 2, 1) + return (key_states, value_states) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) @pytest.mark.forked -def test_falcon(device): +def test_optimized_block(device): if device == "cuda:0" and not torch.cuda.is_available(): pytest.skip("CUDA tests can be run only in CUDA-enabled setups") @@ -108,11 +190,17 @@ def test_falcon(device): quant_type = QuantType.NONE block = config.block_class(config).to(dtype) - block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + + if config.model_type == "falcon": + unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) + elif config.model_type == "llama": + unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype) + else: + pytest.skip(f"This test is not applicable to {config.model_type} models") - unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) unopt_block = convert_block( - unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True + unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True ) unopt_block.load_state_dict(block.state_dict()) From d59c15c5787488005f12162a40930dd284551e02 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 16 Nov 2023 06:12:30 +0300 Subject: [PATCH 11/11] Bump version for inference diagnostics (#543) bump version for inference diagnostics --- src/petals/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index e17b4a1..1af8bf9 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -17,7 +17,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.3.0.dev0" +__version__ = "2.3.0.dev1" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):