From 8666653cf562519cf38e50ccd6712c3f8ae7908e Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Sat, 22 Jul 2023 18:27:58 +0400
Subject: [PATCH] Fix routing through relay, default network RPS, --token, logging, readme (#399)

* Hide GeneratorExit in _iterate_inference_steps()
* Update README.md about `--public_name`
* Use .from_pretrained(..., use_auth_token=token) instead of token=token until it's fully supported across HF libs
* Use default network speed 25 Mbit/s
* Apply relay penalty in max-throughput routing
* Replace RPS with "tokens/sec per block" in logs
* Increase default expiration
---
 README.md                                     |  8 ++-
 src/petals/client/routing/sequence_manager.py | 12 +++-
 src/petals/server/from_pretrained.py          |  2 +-
 src/petals/server/handler.py                  |  2 +-
 src/petals/server/server.py                   |  4 +-
 src/petals/server/throughput.py               | 65 +++++++++----------
 src/petals/utils/auto_config.py               |  8 ++-
 tests/scripts/remove_old_models.py            | 25 -------
 8 files changed, 58 insertions(+), 68 deletions(-)
 delete mode 100644 tests/scripts/remove_old_models.py

diff --git a/README.md b/README.md
index 784d0f1..e4bca6e 100644
--- a/README.md
+++ b/README.md
@@ -34,11 +34,13 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
 
 ### Connect your GPU and increase Petals capacity
 
+Petals is a community-run system — we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serve one of the models!
+
 Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+):
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install --upgrade petals
+pip install git+https://github.com/bigscience-workshop/petals
 python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b
 ```
 
@@ -55,6 +57,8 @@ This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.c
 
 💬 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!
 
+🏆 If you host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks! You can specify it with `--public_name YOUR_NAME`. We will show it once your server loads all blocks.
+
 ### Check out tutorials, examples, and more
 
 Basic tutorials:
@@ -97,7 +101,7 @@ Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/d
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-pip install --upgrade petals
+pip install git+https://github.com/bigscience-workshop/petals
 ```
 
 If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py
index f0b0ce0..c980412 100644
--- a/src/petals/client/routing/sequence_manager.py
+++ b/src/petals/client/routing/sequence_manager.py
@@ -291,7 +291,9 @@ class RemoteSequenceManager:
             # This is okay since false positives are more costly than false negatives here.
             return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
 
-    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
+    def _make_sequence_with_max_throughput(
+        self, start_index: int, end_index: int, *, relay_penalty: float = 0.5
+    ) -> List[RemoteSpanInfo]:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -299,7 +301,13 @@ class RemoteSequenceManager:
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
+            span_weights = np.array(
+                [
+                    span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty)
+                    for span in candidate_spans
+                ],
+                dtype=np.float64,
+            )
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py
index 2a2560b..bfbf03e 100644
--- a/src/petals/server/from_pretrained.py
+++ b/src/petals/server/from_pretrained.py
@@ -40,7 +40,7 @@ def load_pretrained_block(
     max_disk_space: Optional[int] = None,
 ) -> nn.Module:
     if config is None:
-        config = AutoDistributedConfig.from_pretrained(model_name, token=token)
+        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 5d0a3d4..d3776de 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -347,7 +347,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     anext_task.cancel()
                     get_push_task.cancel()
                     return
-        except:
+        except Exception:
             logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
             raise
 
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 5cdca46..6d5c293 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -104,7 +104,7 @@ class Server:
 
         self.block_config = AutoDistributedConfig.from_pretrained(
             converted_model_name_or_path,
-            token=token,
+            use_auth_token=token,
             revision=revision,
         )
 
@@ -117,7 +117,7 @@ class Server:
         self.dht_prefix = dht_prefix
 
         if expiration is None:
-            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+            expiration = max(3 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.expiration = expiration
 
         self.request_timeout = request_timeout
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index d92355e..9e2ad6f 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -96,7 +96,7 @@ def get_server_throughput(
     throughput = throughput_info["forward_rps"] / average_blocks_used
     throughput = min(throughput, throughput_info.get("network_rps", math.inf))
     throughput_info["throughput"] = throughput
-    logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
+    logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")
 
     return throughput_info
 
@@ -109,13 +109,10 @@ def measure_throughput_info(
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
 ) -> Dict[str, float]:
-    """Measure network and compute throughput in forward pass tokens per second"""
-
     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
     )
-
-    throughput_info = {
+    return {
         "inference_rps": measure_compute_rps(
             config,
             device,
@@ -136,37 +133,39 @@ def measure_throughput_info(
             n_steps=10,
             inference=False,
         ),
+        "network_rps": measure_network_rps(config),
     }
-    try:
-        throughput_info["network_rps"] = measure_network_rps(config)
-    except Exception as e:
-        logger.info(f"Network throughput is not available: {e}")
-    return throughput_info
-
-def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]:
-    pipe_recv, pipe_send = mp.Pipe(duplex=False)
-    process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
-    process.start()
-
-    if not pipe_recv.poll(timeout):
-        process.terminate()
-        raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
-    network_info = pipe_recv.recv()
-    if "exception" in network_info:
-        raise RuntimeError(f"speedtest failed: {network_info['exception']}")
+def measure_network_rps(
+    config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 25e6
+) -> Optional[float]:
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
-    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
-    if network_rps == 0:
-        raise RuntimeError("speedtest has returned network_rps == 0")
-
-    logger.info(
-        f"Network throughput: {network_rps:.1f} RPS "
-        f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
-        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
-    )
-    return network_rps
+    try:
+        pipe_recv, pipe_send = mp.Pipe(duplex=False)
+        process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
+        process.start()
+
+        if not pipe_recv.poll(timeout):
+            process.terminate()
+            raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
+        network_info = pipe_recv.recv()
+        if "exception" in network_info:
+            raise RuntimeError(f"speedtest failed: {network_info['exception']}")
+
+        network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
+        if network_rps == 0:
+            raise RuntimeError("speedtest has returned network_rps == 0")
+
+        logger.info(
+            f"Network throughput: {network_rps:.1f} tokens/sec "
+            f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
+            f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
+        )
+        return network_rps
+    except RuntimeError as e:
+        logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s")
+        return default_speed / bits_per_request
 
 
 def _measure_bits_per_second(pipe_send: mp.Pipe):
@@ -215,7 +214,7 @@ def measure_compute_rps(
 
     devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
     logger.info(
-        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block "
+        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block "
         f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
    )
     return device_rps
diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py
index 13c7298..70f37a3 100644
--- a/src/petals/utils/auto_config.py
+++ b/src/petals/utils/auto_config.py
@@ -31,8 +31,12 @@ class _AutoDistributedBase:
 
     @classmethod
     def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
-        if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs:
-            kwargs["token"] = True
+        if (
+            always_needs_auth(model_name_or_path)
+            and kwargs.get("token") is None
+            and kwargs.get("use_auth_token") is None
+        ):
+            kwargs["use_auth_token"] = True
 
         config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
         if config.model_type not in _CLASS_MAPPING:
diff --git a/tests/scripts/remove_old_models.py b/tests/scripts/remove_old_models.py
deleted file mode 100644
index 598fb3b..0000000
--- a/tests/scripts/remove_old_models.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import argparse
-from datetime import datetime
-
-from huggingface_hub import delete_repo, list_models
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
-    parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained")
-    parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument("--dry_run", action="store_true")
-
-    args = parser.parse_args()
-
-    for model in list_models(author=args.author, full=True):
-        last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")
-
-        if model.modelId.endswith("-main") or "/test-" not in model.modelId:
-            continue  # remove only test models
-
-        if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
-            if args.dry_run:
-                print(f"{model.modelId} can be deleted")
-            else:
-                delete_repo(repo_id=model.modelId, token=args.use_auth_token)
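
Note (not part of the patch): for readers who want to see the routing change in isolation, below is a minimal, self-contained sketch of the relay-penalty weighting that `_make_sequence_with_max_throughput` now applies. `ServerInfo` and `Span` are simplified stand-ins for the real Petals dataclasses; only the `throughput` and `using_relay` fields mirror the ones used in the diff above.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ServerInfo:
    throughput: float   # advertised speed, tokens/sec per block
    using_relay: bool   # True if the server is reachable only through a relay node


@dataclass
class Span:
    server_info: ServerInfo


def choose_span(candidate_spans, relay_penalty: float = 0.5):
    """Sample one span, weighting by throughput and discounting relayed servers (illustrative only)."""
    weights = np.array(
        [
            span.server_info.throughput * (relay_penalty if span.server_info.using_relay else 1.0)
            for span in candidate_spans
        ],
        dtype=np.float64,
    )
    # Proportional sampling, as in the patched _make_sequence_with_max_throughput()
    return np.random.choice(candidate_spans, p=weights / weights.sum())


if __name__ == "__main__":
    spans = [Span(ServerInfo(100.0, using_relay=True)), Span(ServerInfo(80.0, using_relay=False))]
    print(choose_span(spans).server_info)
```

With `relay_penalty=0.5`, a relayed server advertising 100 tokens/sec competes as if it offered 50, so the direct server at 80 tokens/sec is picked more often, which is the intent of the "Apply relay penalty in max-throughput routing" change.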