Fix issues related to `petals` as a module (#159)

1. Added `from petals.client import *` to `petals/__init__.py`, so you can write just that:

    ```python
    from petals import DistributedBloomForCausalLM
    ```

    I didn't do the same with server, since its classes are supposed to by used by `petals.cli.run_server`, not end-users. Though it's still possible to do `from petals.server.smth import smth` if necessary.

2. Fixed one more logging issue: log lines from hivemind were shown twice due to a bug in #156.

3. Removed unused `runtime.py`, since the server actually uses `hivemind.moe.Runtime`, and `runtime.py` has no significant changes comparing to it.
pull/160/head
Alexander Borzunov 1 year ago committed by GitHub
parent 91898c3c90
commit 523a7cad33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,7 +8,7 @@
Generate text using distributed BLOOM and fine-tune it for your own tasks:
```python
from petals.client import DistributedBloomForCausalLM
from petals import DistributedBloomForCausalLM
# Embeddings & prompts are on your device, BLOOM blocks are distributed across the Internet
model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", tuning_mode="ptune")
@ -68,13 +68,13 @@ Check out more tutorials:
📜 &nbsp;<b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
</p>
### 📋 Model's terms of use
### 🔒 Privacy and security
Before building your own application that runs a language model with Petals, please make sure that you are familiar with the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
The Petals public swarm is designed for research and academic use. **Please do not use the public swarm to process sensitive data.** We ask for that because it is an open network, and it is technically possible for peers serving model layers to recover input data and model outputs or modify them in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process your data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
### 🔒 Privacy and security
### 📋 Model's terms of use
**If you work with sensitive data, do not use the public swarm.** This is important because it's technically possible for peers serving model layers to recover input data and model outputs, or modify the outputs in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process this data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
Before building your own application that runs a language model with Petals, please check out the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
## FAQ

@ -48,7 +48,7 @@
"outputs": [],
"source": [
"import os\n",
" \n",
"\n",
"import torch\n",
"import transformers\n",
"import wandb\n",
@ -58,8 +58,7 @@
"from torch.utils.data import DataLoader\n",
"from transformers import BloomTokenizerFast, get_scheduler\n",
"\n",
"# Import a Petals model\n",
"from petals.client.remote_model import DistributedBloomForCausalLM"
"from petals import DistributedBloomForCausalLM"
]
},
{

@ -48,22 +48,19 @@
"outputs": [],
"source": [
"import os\n",
" \n",
"import torch\n",
"import transformers\n",
"import wandb\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import transformers\n",
"import wandb\n",
"from datasets import load_dataset, load_metric\n",
"from tqdm import tqdm\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader\n",
"from transformers import BloomTokenizerFast, get_scheduler\n",
"\n",
"# Import a Petals model\n",
"from petals.client.remote_model import DistributedBloomForSequenceClassification"
"from petals import DistributedBloomForSequenceClassification"
]
},
{

@ -1,5 +1,6 @@
import petals.utils.logging
from petals.client import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "1.0alpha1"
petals.utils.logging.initialize_logs()
_initialize_logs()

@ -1,5 +1,10 @@
from petals.client.inference_session import InferenceSession
from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from petals.client.remote_model import (
DistributedBloomConfig,
DistributedBloomForCausalLM,
DistributedBloomForSequenceClassification,
DistributedBloomModel,
)
from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@ -1,198 +0,0 @@
import multiprocessing as mp
import multiprocessing.pool
import threading
from collections import defaultdict
from itertools import chain
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional
import torch
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from prefetch_generator import BackgroundGenerator
logger = get_logger(__name__)
class Runtime(threading.Thread):
"""
A group of processes that processes incoming requests for multiple module backends on a shared device.
Runtime is usually created and managed by Server, humans need not apply.
For debugging, you can start runtime manually with .start() or .run()
>>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
>>> runtime = Runtime(module_backends)
>>> runtime.start() # start runtime in background thread. To start in current thread, use runtime.run()
>>> runtime.ready.wait() # await for runtime to load all blocks on device and create request pools
>>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
>>> print("Returned:", future.result())
>>> runtime.shutdown()
:param module_backends: a dict [block uid -> ModuleBackend]
:param prefetch_batches: form up to this many batches in advance
:param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
:param device: if specified, moves all blocks and data to this device via .to(device=device).
If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
:param stats_report_interval: interval to collect and log statistics about runtime performance
"""
SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
def __init__(
self,
module_backends: Dict[str, ModuleBackend],
prefetch_batches: int = 1,
sender_threads: int = 1,
device: torch.device = None,
stats_report_interval: Optional[int] = None,
):
super().__init__()
self.module_backends = module_backends
self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
self.shutdown_trigger = mp.Event()
self.ready = mp.Event() # event is set iff server is currently running and ready to accept batches
self.stats_report_interval = stats_report_interval
if self.stats_report_interval is not None:
self.stats_reporter = StatsReporter(self.stats_report_interval)
def run(self):
for pool in self.pools:
if not pool.is_alive():
pool.start()
if self.device is not None:
for backend in self.module_backends.values():
backend.module.to(self.device)
with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
try:
self.ready.set()
if self.stats_report_interval is not None:
self.stats_reporter.start()
logger.info("Started")
batch_iterator = self.iterate_minibatches_from_pools()
if self.prefetch_batches > 0:
batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
for pool, batch_index, batch in batch_iterator:
logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
start = time()
try:
outputs = pool.process_func(*batch)
output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
batch_processing_time = time() - start
batch_size = outputs[0].size(0)
logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
if self.stats_report_interval is not None:
self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
except KeyboardInterrupt:
raise
except BaseException as exception:
logger.exception(f"Caught {exception}, attempting to recover")
output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
finally:
if not self.shutdown_trigger.is_set():
self.shutdown()
def shutdown(self):
"""Gracefully terminate a running runtime."""
logger.info("Shutting down")
self.ready.clear()
if self.stats_report_interval is not None:
self.stats_reporter.stop.set()
self.stats_reporter.join()
logger.debug("Terminating pools")
for pool in self.pools:
if pool.is_alive():
pool.shutdown()
logger.debug("Pools terminated")
# trigger background thread to shutdown
self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
self.shutdown_trigger.set()
def iterate_minibatches_from_pools(self, timeout=None):
"""
Chooses pool according to priority, then copies exposed batch and frees the buffer
"""
with DefaultSelector() as selector:
for pool in self.pools:
selector.register(pool.batch_receiver, EVENT_READ, pool)
selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
while True:
# wait until at least one batch_receiver becomes available
logger.debug("Waiting for inputs from task pools")
ready_fds = selector.select()
ready_objects = {key.data for (key, events) in ready_fds}
if self.SHUTDOWN_TRIGGER in ready_objects:
break # someone asked us to shutdown, break from the loop
logger.debug("Choosing the pool with first priority")
pool = min(ready_objects, key=lambda pool: pool.priority)
logger.debug(f"Loading batch from {pool.name}")
batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
logger.debug(f"Loaded batch from {pool.name}")
yield pool, batch_index, batch_tensors
BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
class StatsReporter(threading.Thread):
def __init__(self, report_interval: int):
super().__init__()
self.report_interval = report_interval
self.stop = threading.Event()
self.stats_queue = SimpleQueue()
def run(self):
while not self.stop.wait(self.report_interval):
pool_batch_stats = defaultdict(list)
while not self.stats_queue.empty():
pool_uid, batch_stats = self.stats_queue.get()
pool_batch_stats[pool_uid].append(batch_stats)
total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
for pool_uid, pool_stats in pool_batch_stats.items():
total_batches = len(pool_stats)
total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
batches_to_time = total_batches / total_time
batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")
examples_to_time = total_examples / total_time
example_performance = f"{examples_to_time:.2f} " + (
"examples/s" if examples_to_time > 1 else "s/example"
)
logger.info(
f"{pool_uid}: "
f"{total_batches} batches ({batch_performance}), "
f"{total_examples} examples ({example_performance}), "
f"avg batch size {avg_batch_size:.2f}"
)
def report_stats(self, pool_uid, batch_size, processing_time):
batch_stats = BatchStats(batch_size, processing_time)
self.stats_queue.put_nowait((pool_uid, batch_stats))

@ -25,7 +25,10 @@ def initialize_logs():
os.environ["HIVEMIND_COLORS"] = "True"
importlib.reload(hm_logging)
hm_logging.get_logger().handlers.clear() # Remove extra default handlers on Colab
# Remove log handlers from previous import of hivemind.utils.logging and extra handlers on Colab
hm_logging.get_logger().handlers.clear()
hm_logging.get_logger("hivemind").handlers.clear()
hm_logging.use_hivemind_log_handler("in_root_logger")
# We suppress asyncio error logs by default since they are mostly not relevant for the end user,

@ -3,8 +3,8 @@ import time
import pytest
import torch
from hivemind.moe.server.runtime import Runtime
from petals.server.runtime import Runtime
from petals.server.task_pool import PrioritizedTaskPool

Loading…
Cancel
Save