Fix issues related to `petals` as a module (#159)

1. Added `from petals.client import *` to `petals/__init__.py`, so now you can simply write:

    ```python
    from petals import DistributedBloomForCausalLM
    ```

    I didn't do the same for the server, since its classes are supposed to be used by `petals.cli.run_server`, not end users. Still, it's possible to do `from petals.server.smth import smth` if necessary. A short sketch of the resulting import paths follows this list.

2. Fixed one more logging issue: log lines from hivemind were shown twice due to a bug in #156.

3. Removed the unused `runtime.py`, since the server actually uses `hivemind.moe.Runtime`, and `runtime.py` has no significant changes compared to it.
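With the package-level re-export in place, the shorter imports used in the updated README and notebooks below resolve to the same classes as the old, fully-qualified paths (which keep working). A quick sketch:

```python
# New top-level imports enabled by re-exporting petals.client in petals/__init__.py
from petals import DistributedBloomForCausalLM, DistributedBloomForSequenceClassification

# The longer paths still point to the same classes
from petals.client import DistributedBloomForCausalLM
from petals.client.remote_model import DistributedBloomForSequenceClassification
```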
Alexander Borzunov committed 523a7cad33 (parent 91898c3c90)

@@ -8,7 +8,7 @@
 Generate text using distributed BLOOM and fine-tune it for your own tasks:
 ```python
-from petals.client import DistributedBloomForCausalLM
+from petals import DistributedBloomForCausalLM
 # Embeddings & prompts are on your device, BLOOM blocks are distributed across the Internet
 model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", tuning_mode="ptune")
@@ -68,13 +68,13 @@ Check out more tutorials:
 📜 &nbsp;<b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
 </p>
-### 📋 Model's terms of use
-Before building your own application that runs a language model with Petals, please make sure that you are familiar with the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
-### 🔒 Privacy and security
-**If you work with sensitive data, do not use the public swarm.** This is important because it's technically possible for peers serving model layers to recover input data and model outputs, or modify the outputs in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process this data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
+### 🔒 Privacy and security
+The Petals public swarm is designed for research and academic use. **Please do not use the public swarm to process sensitive data.** We ask for that because it is an open network, and it is technically possible for peers serving model layers to recover input data and model outputs or modify them in a malicious way. Instead, you can [set up a private Petals swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) hosted by people and organization you trust, who are authorized to process your data. We discuss privacy and security in more detail [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
+### 📋 Model's terms of use
+Before building your own application that runs a language model with Petals, please check out the model's **terms of use, risks, and limitations**. In case of BLOOM, they are described in its [model card](https://huggingface.co/bigscience/bloom) and [license](https://huggingface.co/spaces/bigscience/license).
 ## FAQ

@@ -48,7 +48,7 @@
 "outputs": [],
 "source": [
 "import os\n",
-" \n",
+"\n",
 "import torch\n",
 "import transformers\n",
 "import wandb\n",
@@ -58,8 +58,7 @@
 "from torch.utils.data import DataLoader\n",
 "from transformers import BloomTokenizerFast, get_scheduler\n",
 "\n",
-"# Import a Petals model\n",
-"from petals.client.remote_model import DistributedBloomForCausalLM"
+"from petals import DistributedBloomForCausalLM"
 ]
 },
 {

@@ -48,22 +48,19 @@
 "outputs": [],
 "source": [
 "import os\n",
-" \n",
-"import torch\n",
-"import transformers\n",
-"import wandb\n",
 "\n",
+"import torch\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
-"\n",
+"import transformers\n",
+"import wandb\n",
 "from datasets import load_dataset, load_metric\n",
 "from tqdm import tqdm\n",
 "from torch.optim import AdamW\n",
 "from torch.utils.data import DataLoader\n",
 "from transformers import BloomTokenizerFast, get_scheduler\n",
 "\n",
-"# Import a Petals model\n",
-"from petals.client.remote_model import DistributedBloomForSequenceClassification"
+"from petals import DistributedBloomForSequenceClassification"
 ]
 },
 {

@@ -1,5 +1,6 @@
-import petals.utils.logging
+from petals.client import *
+from petals.utils.logging import initialize_logs as _initialize_logs
 __version__ = "1.0alpha1"
-petals.utils.logging.initialize_logs()
+_initialize_logs()

@@ -1,5 +1,10 @@
 from petals.client.inference_session import InferenceSession
-from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
+from petals.client.remote_model import (
+    DistributedBloomConfig,
+    DistributedBloomForCausalLM,
+    DistributedBloomForSequenceClassification,
+    DistributedBloomModel,
+)
 from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
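Since `petals/client/__init__.py` does not appear to define `__all__`, the `from petals.client import *` line in `petals/__init__.py` above re-exports every name imported here at the package root; for example, the second notebook above now relies on:

```python
from petals import DistributedBloomForSequenceClassification
```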

@@ -1,198 +0,0 @@
import multiprocessing as mp
import multiprocessing.pool
import threading
from collections import defaultdict
from itertools import chain
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional

import torch
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from prefetch_generator import BackgroundGenerator

logger = get_logger(__name__)


class Runtime(threading.Thread):
    """
    A group of processes that processes incoming requests for multiple module backends on a shared device.
    Runtime is usually created and managed by Server, humans need not apply.

    For debugging, you can start runtime manually with .start() or .run()

    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
    >>> runtime = Runtime(module_backends)
    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
    >>> runtime.ready.wait()  # await for runtime to load all blocks on device and create request pools
    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param module_backends: a dict [block uid -> ModuleBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
    :param device: if specified, moves all blocks and data to this device via .to(device=device).
      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
    :param stats_report_interval: interval to collect and log statistics about runtime performance
    """

    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def __init__(
        self,
        module_backends: Dict[str, ModuleBackend],
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        device: torch.device = None,
        stats_report_interval: Optional[int] = None,
    ):
        super().__init__()
        self.module_backends = module_backends
        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.shutdown_trigger = mp.Event()
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
        self.stats_report_interval = stats_report_interval
        if self.stats_report_interval is not None:
            self.stats_reporter = StatsReporter(self.stats_report_interval)

    def run(self):
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for backend in self.module_backends.values():
                backend.module.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                if self.stats_report_interval is not None:
                    self.stats_reporter.start()
                logger.info("Started")

                batch_iterator = self.iterate_minibatches_from_pools()
                if self.prefetch_batches > 0:
                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)

                for pool, batch_index, batch in batch_iterator:
                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

                    start = time()
                    try:
                        outputs = pool.process_func(*batch)
                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])

                        batch_processing_time = time() - start
                        batch_size = outputs[0].size(0)
                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                        if self.stats_report_interval is not None:
                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
                    except KeyboardInterrupt:
                        raise
                    except BaseException as exception:
                        logger.exception(f"Caught {exception}, attempting to recover")
                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
            finally:
                if not self.shutdown_trigger.is_set():
                    self.shutdown()

    def shutdown(self):
        """Gracefully terminate a running runtime."""
        logger.info("Shutting down")
        self.ready.clear()

        if self.stats_report_interval is not None:
            self.stats_reporter.stop.set()
            self.stats_reporter.join()

        logger.debug("Terminating pools")
        for pool in self.pools:
            if pool.is_alive():
                pool.shutdown()
        logger.debug("Pools terminated")

        # trigger background thread to shutdown
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
        self.shutdown_trigger.set()

    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Chooses pool according to priority, then copies exposed batch and frees the buffer
        """
        with DefaultSelector() as selector:
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)

            while True:
                # wait until at least one batch_receiver becomes available
                logger.debug("Waiting for inputs from task pools")
                ready_fds = selector.select()
                ready_objects = {key.data for (key, events) in ready_fds}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shutdown, break from the loop

                logger.debug("Choosing the pool with first priority")
                pool = min(ready_objects, key=lambda pool: pool.priority)

                logger.debug(f"Loading batch from {pool.name}")
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                logger.debug(f"Loaded batch from {pool.name}")
                yield pool, batch_index, batch_tensors


BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))


class StatsReporter(threading.Thread):
    def __init__(self, report_interval: int):
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()
        self.stats_queue = SimpleQueue()

    def run(self):
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
                batches_to_time = total_batches / total_time
                batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")

                examples_to_time = total_examples / total_time
                example_performance = f"{examples_to_time:.2f} " + (
                    "examples/s" if examples_to_time > 1 else "s/example"
                )

                logger.info(
                    f"{pool_uid}: "
                    f"{total_batches} batches ({batch_performance}), "
                    f"{total_examples} examples ({example_performance}), "
                    f"avg batch size {avg_batch_size:.2f}"
                )

    def report_stats(self, pool_uid, batch_size, processing_time):
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))

@@ -25,7 +25,10 @@ def initialize_logs():
     os.environ["HIVEMIND_COLORS"] = "True"
     importlib.reload(hm_logging)
-    hm_logging.get_logger().handlers.clear()  # Remove extra default handlers on Colab
+    # Remove log handlers from previous import of hivemind.utils.logging and extra handlers on Colab
+    hm_logging.get_logger().handlers.clear()
+    hm_logging.get_logger("hivemind").handlers.clear()
     hm_logging.use_hivemind_log_handler("in_root_logger")
 # We suppress asyncio error logs by default since they are mostly not relevant for the end user,
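For context on item 2 of the commit message: a rough, self-contained illustration (using the standard `logging` module rather than hivemind's `hm_logging` wrapper) of why handlers left over from a previous import make every record appear twice, and why clearing them before re-attaching fixes it:

```python
import logging

logger = logging.getLogger("hivemind")
logger.addHandler(logging.StreamHandler())  # handler attached by an earlier import
logger.addHandler(logging.StreamHandler())  # handler attached again after the module reload
logger.warning("hello")                     # printed twice: once per attached handler

logger.handlers.clear()                     # the fix: drop the stale handlers first...
logger.addHandler(logging.StreamHandler())  # ...then attach exactly one handler
logger.warning("hello")                     # printed once
```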

@@ -3,8 +3,8 @@ import time
 import pytest
 import torch
+from hivemind.moe.server.runtime import Runtime
-from petals.server.runtime import Runtime
 from petals.server.task_pool import PrioritizedTaskPool
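With `petals/server/runtime.py` gone, both the server and this test use hivemind's `Runtime` directly. A minimal construction sketch, assuming hivemind's class keeps the interface documented in the removed file above; `backends` is a placeholder you would fill with real `ModuleBackend` objects:

```python
from typing import Dict

from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.runtime import Runtime

backends: Dict[str, ModuleBackend] = {}  # placeholder: map block uids to actual ModuleBackend instances

runtime = Runtime(module_backends=backends, prefetch_batches=1, sender_threads=1)
runtime.start()       # process incoming batches in a background thread (runtime.run() blocks instead)
runtime.ready.wait()  # wait until pools are running and modules are on the device
runtime.shutdown()    # gracefully stop the pools and the background thread
```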
