diff --git a/README.md b/README.md
index 5555f75..7a727df 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
 Generate text using distributed BLOOM and fine-tune it for your own tasks:
 
 ```python
-from petals.client import DistributedBloomForCausalLM
+from petals import DistributedBloomForCausalLM
 
 # Embeddings & prompts are on your device, BLOOM blocks are distributed across the Internet
 model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", tuning_mode="ptune")
@@ -68,13 +68,13 @@
 Check out more tutorials:
 
 📜  Read paper

-### 📋 Model's terms of use
+### 🔒 Privacy and security
 
-Before building your own application that runs a language model with Petals, please make sure that you are familiar with the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
+The Petals public swarm is designed for research and academic use. **Please do not use the public swarm to process sensitive data.** We ask this because it is an open network, and it is technically possible for peers serving model layers to recover input data and model outputs or modify them in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organizations you trust, who are authorized to process your data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
 
-### 🔒 Privacy and security
+### 📋 Model's terms of use
 
-**If you work with sensitive data, do not use the public swarm.** This is important because it's technically possible for peers serving model layers to recover input data and model outputs, or modify the outputs in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process this data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
+Before building your own application that runs a language model with Petals, please check out the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
 
 ## FAQ
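The README hunks above switch the user-facing import from `petals.client` to the package root. For reference, a minimal end-to-end sketch under the new import path; the prompt, `max_new_tokens` value, and tokenizer choice are illustrative and not part of this patch:

```python
from transformers import BloomTokenizerFast

from petals import DistributedBloomForCausalLM  # new top-level import introduced by this patch

MODEL_NAME = "bigscience/bloom-petals"
tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)  # BLOOM blocks are served by remote peers

# Illustrative generation call; decoding settings are arbitrary
inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))
```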
diff --git a/examples/prompt-tuning-personachat.ipynb b/examples/prompt-tuning-personachat.ipynb
index 868299b..9e438cc 100644
--- a/examples/prompt-tuning-personachat.ipynb
+++ b/examples/prompt-tuning-personachat.ipynb
@@ -48,7 +48,7 @@
    "outputs": [],
    "source": [
     "import os\n",
-    " \n",
+    "\n",
     "import torch\n",
     "import transformers\n",
     "import wandb\n",
@@ -58,8 +58,7 @@
     "from torch.utils.data import DataLoader\n",
     "from transformers import BloomTokenizerFast, get_scheduler\n",
     "\n",
-    "# Import a Petals model\n",
-    "from petals.client.remote_model import DistributedBloomForCausalLM"
+    "from petals import DistributedBloomForCausalLM"
    ]
   },
   {
diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb
index dce7766..41d37bb 100644
--- a/examples/prompt-tuning-sst2.ipynb
+++ b/examples/prompt-tuning-sst2.ipynb
@@ -48,22 +48,19 @@
    "outputs": [],
    "source": [
     "import os\n",
-    " \n",
-    "import torch\n",
-    "import transformers\n",
-    "import wandb\n",
     "\n",
+    "import torch\n",
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
-    "\n",
+    "import transformers\n",
+    "import wandb\n",
     "from datasets import load_dataset, load_metric\n",
     "from tqdm import tqdm\n",
     "from torch.optim import AdamW\n",
     "from torch.utils.data import DataLoader\n",
     "from transformers import BloomTokenizerFast, get_scheduler\n",
     "\n",
-    "# Import a Petals model\n",
-    "from petals.client.remote_model import DistributedBloomForSequenceClassification"
+    "from petals import DistributedBloomForSequenceClassification"
    ]
   },
   {
diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index 066180b..9998543 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -1,5 +1,6 @@
-import petals.utils.logging
+from petals.client import *
+from petals.utils.logging import initialize_logs as _initialize_logs
 
 __version__ = "1.0alpha1"
 
-petals.utils.logging.initialize_logs()
+_initialize_logs()
diff --git a/src/petals/client/__init__.py b/src/petals/client/__init__.py
index 93fc8a6..b728962 100644
--- a/src/petals/client/__init__.py
+++ b/src/petals/client/__init__.py
@@ -1,5 +1,10 @@
 from petals.client.inference_session import InferenceSession
-from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
+from petals.client.remote_model import (
+    DistributedBloomConfig,
+    DistributedBloomForCausalLM,
+    DistributedBloomForSequenceClassification,
+    DistributedBloomModel,
+)
 from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
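The `src/petals/__init__.py` change above re-exports the client API at the package root via `from petals.client import *`, which is what enables the shorter imports in the README and the notebooks. A quick sanity-check sketch of the intended equivalence (a hypothetical snippet, not part of the patch):

```python
import petals

# New top-level import made possible by the re-export in petals/__init__.py
from petals import DistributedBloomForSequenceClassification

# The old deep import path still exists and should resolve to the same class object
from petals.client.remote_model import DistributedBloomForSequenceClassification as _OldPath

assert DistributedBloomForSequenceClassification is _OldPath
print(petals.__version__)  # "1.0alpha1" at this point in the history
```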
diff --git a/src/petals/server/runtime.py b/src/petals/server/runtime.py
deleted file mode 100644
index 11547aa..0000000
--- a/src/petals/server/runtime.py
+++ /dev/null
@@ -1,198 +0,0 @@
-import multiprocessing as mp
-import multiprocessing.pool
-import threading
-from collections import defaultdict
-from itertools import chain
-from queue import SimpleQueue
-from selectors import EVENT_READ, DefaultSelector
-from statistics import mean
-from time import time
-from typing import Dict, NamedTuple, Optional
-
-import torch
-from hivemind.moe.server.module_backend import ModuleBackend
-from hivemind.utils import get_logger
-from prefetch_generator import BackgroundGenerator
-
-logger = get_logger(__name__)
-
-
-class Runtime(threading.Thread):
-    """
-    A group of processes that processes incoming requests for multiple module backends on a shared device.
-    Runtime is usually created and managed by Server, humans need not apply.
-
-    For debugging, you can start runtime manually with .start() or .run()
-
-    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
-    >>> runtime = Runtime(module_backends)
-    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
-    >>> runtime.ready.wait()  # await for runtime to load all blocks on device and create request pools
-    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
-    >>> print("Returned:", future.result())
-    >>> runtime.shutdown()
-
-    :param module_backends: a dict [block uid -> ModuleBackend]
-    :param prefetch_batches: form up to this many batches in advance
-    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
-    :param device: if specified, moves all blocks and data to this device via .to(device=device).
-      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
-
-    :param stats_report_interval: interval to collect and log statistics about runtime performance
-    """
-
-    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
-
-    def __init__(
-        self,
-        module_backends: Dict[str, ModuleBackend],
-        prefetch_batches: int = 1,
-        sender_threads: int = 1,
-        device: torch.device = None,
-        stats_report_interval: Optional[int] = None,
-    ):
-        super().__init__()
-        self.module_backends = module_backends
-        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
-        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
-        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
-        self.shutdown_trigger = mp.Event()
-        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
-
-        self.stats_report_interval = stats_report_interval
-        if self.stats_report_interval is not None:
-            self.stats_reporter = StatsReporter(self.stats_report_interval)
-
-    def run(self):
-        for pool in self.pools:
-            if not pool.is_alive():
-                pool.start()
-        if self.device is not None:
-            for backend in self.module_backends.values():
-                backend.module.to(self.device)
-
-        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
-            try:
-                self.ready.set()
-                if self.stats_report_interval is not None:
-                    self.stats_reporter.start()
-                logger.info("Started")
-
-                batch_iterator = self.iterate_minibatches_from_pools()
-                if self.prefetch_batches > 0:
-                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
-
-                for pool, batch_index, batch in batch_iterator:
-                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
-
-                    start = time()
-                    try:
-                        outputs = pool.process_func(*batch)
-                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
-
-                        batch_processing_time = time() - start
-
-                        batch_size = outputs[0].size(0)
-                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
-
-                        if self.stats_report_interval is not None:
-                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
-
-                    except KeyboardInterrupt:
-                        raise
-                    except BaseException as exception:
-                        logger.exception(f"Caught {exception}, attempting to recover")
-                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
-
-            finally:
-                if not self.shutdown_trigger.is_set():
-                    self.shutdown()
-
-    def shutdown(self):
-        """Gracefully terminate a running runtime."""
-        logger.info("Shutting down")
-        self.ready.clear()
-
-        if self.stats_report_interval is not None:
-            self.stats_reporter.stop.set()
-            self.stats_reporter.join()
-
-        logger.debug("Terminating pools")
-        for pool in self.pools:
-            if pool.is_alive():
-                pool.shutdown()
-        logger.debug("Pools terminated")
-
-        # trigger background thread to shutdown
-        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
-        self.shutdown_trigger.set()
-
-    def iterate_minibatches_from_pools(self, timeout=None):
-        """
-        Chooses pool according to priority, then copies exposed batch and frees the buffer
-        """
-        with DefaultSelector() as selector:
-            for pool in self.pools:
-                selector.register(pool.batch_receiver, EVENT_READ, pool)
-            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
-
-            while True:
-                # wait until at least one batch_receiver becomes available
-                logger.debug("Waiting for inputs from task pools")
-                ready_fds = selector.select()
-                ready_objects = {key.data for (key, events) in ready_fds}
-                if self.SHUTDOWN_TRIGGER in ready_objects:
-                    break  # someone asked us to shutdown, break from the loop
-
-                logger.debug("Choosing the pool with first priority")
-
-                pool = min(ready_objects, key=lambda pool: pool.priority)
-
-                logger.debug(f"Loading batch from {pool.name}")
-                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
-                logger.debug(f"Loaded batch from {pool.name}")
-                yield pool, batch_index, batch_tensors
-
-
-BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
-
-
-class StatsReporter(threading.Thread):
-    def __init__(self, report_interval: int):
-        super().__init__()
-        self.report_interval = report_interval
-        self.stop = threading.Event()
-        self.stats_queue = SimpleQueue()
-
-    def run(self):
-        while not self.stop.wait(self.report_interval):
-            pool_batch_stats = defaultdict(list)
-            while not self.stats_queue.empty():
-                pool_uid, batch_stats = self.stats_queue.get()
-                pool_batch_stats[pool_uid].append(batch_stats)
-
-            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
-            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
-            for pool_uid, pool_stats in pool_batch_stats.items():
-                total_batches = len(pool_stats)
-                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
-                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
-                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
-                batches_to_time = total_batches / total_time
-                batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")
-
-                examples_to_time = total_examples / total_time
-                example_performance = f"{examples_to_time:.2f} " + (
-                    "examples/s" if examples_to_time > 1 else "s/example"
-                )
-
-                logger.info(
-                    f"{pool_uid}: "
-                    f"{total_batches} batches ({batch_performance}), "
-                    f"{total_examples} examples ({example_performance}), "
-                    f"avg batch size {avg_batch_size:.2f}"
-                )
-
-    def report_stats(self, pool_uid, batch_size, processing_time):
-        batch_stats = BatchStats(batch_size, processing_time)
-        self.stats_queue.put_nowait((pool_uid, batch_stats))
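The `runtime.py` deleted above appears to be a copy of hivemind's own `Runtime`; the test change at the end of this diff already imports it from hivemind instead. A minimal sketch of the swap, assuming `hivemind.moe.server.runtime.Runtime` keeps the interface of the deleted copy (constructor arguments, `.start()`, `.ready`, `.shutdown()`):

```python
# Old import (module deleted by this patch):
#   from petals.server.runtime import Runtime
# New import, as used by tests/test_priority_pool.py further down:
from hivemind.moe.server.runtime import Runtime

from petals.server.task_pool import PrioritizedTaskPool  # Petals' prioritized pools still plug into it

# The calling convention from the deleted class's docstring is assumed unchanged:
#   runtime = Runtime(module_backends); runtime.start(); runtime.ready.wait(); ...; runtime.shutdown()
```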
- """Gracefully terminate a running runtime.""" - logger.info("Shutting down") - self.ready.clear() - - if self.stats_report_interval is not None: - self.stats_reporter.stop.set() - self.stats_reporter.join() - - logger.debug("Terminating pools") - for pool in self.pools: - if pool.is_alive(): - pool.shutdown() - logger.debug("Pools terminated") - - # trigger background thread to shutdown - self.shutdown_send.send(self.SHUTDOWN_TRIGGER) - self.shutdown_trigger.set() - - def iterate_minibatches_from_pools(self, timeout=None): - """ - Chooses pool according to priority, then copies exposed batch and frees the buffer - """ - with DefaultSelector() as selector: - for pool in self.pools: - selector.register(pool.batch_receiver, EVENT_READ, pool) - selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER) - - while True: - # wait until at least one batch_receiver becomes available - logger.debug("Waiting for inputs from task pools") - ready_fds = selector.select() - ready_objects = {key.data for (key, events) in ready_fds} - if self.SHUTDOWN_TRIGGER in ready_objects: - break # someone asked us to shutdown, break from the loop - - logger.debug("Choosing the pool with first priority") - - pool = min(ready_objects, key=lambda pool: pool.priority) - - logger.debug(f"Loading batch from {pool.name}") - batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device) - logger.debug(f"Loaded batch from {pool.name}") - yield pool, batch_index, batch_tensors - - -BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float))) - - -class StatsReporter(threading.Thread): - def __init__(self, report_interval: int): - super().__init__() - self.report_interval = report_interval - self.stop = threading.Event() - self.stats_queue = SimpleQueue() - - def run(self): - while not self.stop.wait(self.report_interval): - pool_batch_stats = defaultdict(list) - while not self.stats_queue.empty(): - pool_uid, batch_stats = self.stats_queue.get() - pool_batch_stats[pool_uid].append(batch_stats) - - total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values()) - logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:") - for pool_uid, pool_stats in pool_batch_stats.items(): - total_batches = len(pool_stats) - total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats) - avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats) - total_time = sum(batch_stats.processing_time for batch_stats in pool_stats) - batches_to_time = total_batches / total_time - batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch") - - examples_to_time = total_examples / total_time - example_performance = f"{examples_to_time:.2f} " + ( - "examples/s" if examples_to_time > 1 else "s/example" - ) - - logger.info( - f"{pool_uid}: " - f"{total_batches} batches ({batch_performance}), " - f"{total_examples} examples ({example_performance}), " - f"avg batch size {avg_batch_size:.2f}" - ) - - def report_stats(self, pool_uid, batch_size, processing_time): - batch_stats = BatchStats(batch_size, processing_time) - self.stats_queue.put_nowait((pool_uid, batch_stats)) diff --git a/src/petals/utils/logging.py b/src/petals/utils/logging.py index e4732f2..6fe099f 100644 --- a/src/petals/utils/logging.py +++ b/src/petals/utils/logging.py @@ -25,7 +25,10 @@ def initialize_logs(): os.environ["HIVEMIND_COLORS"] = "True" importlib.reload(hm_logging) - 
diff --git a/tests/test_priority_pool.py b/tests/test_priority_pool.py
index cd946ee..2623bb1 100644
--- a/tests/test_priority_pool.py
+++ b/tests/test_priority_pool.py
@@ -3,8 +3,8 @@
 import time
 
 import pytest
 import torch
+from hivemind.moe.server.runtime import Runtime
 
-from petals.server.runtime import Runtime
 from petals.server.task_pool import PrioritizedTaskPool