Fix issues related to `petals` as a module (#159)
1. Added `from petals.client import *` to `petals/__init__.py`, so you can write just this: `from petals import DistributedBloomForCausalLM`. I didn't do the same with the server, since its classes are supposed to be used by `petals.cli.run_server`, not by end users, though it's still possible to do `from petals.server.smth import smth` if necessary. 2. Fixed one more logging issue: log lines from hivemind were shown twice due to a bug in #156. 3. Removed the unused `runtime.py`, since the server actually uses `hivemind.moe.Runtime`, and `runtime.py` has no significant changes compared to it.
pull/160/head
parent
91898c3c90
commit
523a7cad33
@ -1,5 +1,6 @@
|
||||
import petals.utils.logging
|
||||
from petals.client import *
|
||||
from petals.utils.logging import initialize_logs as _initialize_logs
|
||||
|
||||
__version__ = "1.0alpha1"
|
||||
|
||||
petals.utils.logging.initialize_logs()
|
||||
_initialize_logs()
|
||||
|
@ -1,5 +1,10 @@
|
||||
from petals.client.inference_session import InferenceSession
|
||||
from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
|
||||
from petals.client.remote_model import (
|
||||
DistributedBloomConfig,
|
||||
DistributedBloomForCausalLM,
|
||||
DistributedBloomForSequenceClassification,
|
||||
DistributedBloomModel,
|
||||
)
|
||||
from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
|
||||
from petals.client.routing.sequence_manager import RemoteSequenceManager
|
||||
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
||||
|
@ -1,198 +0,0 @@
|
||||
import multiprocessing as mp
|
||||
import multiprocessing.pool
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from queue import SimpleQueue
|
||||
from selectors import EVENT_READ, DefaultSelector
|
||||
from statistics import mean
|
||||
from time import time
|
||||
from typing import Dict, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
from hivemind.moe.server.module_backend import ModuleBackend
|
||||
from hivemind.utils import get_logger
|
||||
from prefetch_generator import BackgroundGenerator
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Runtime(threading.Thread):
    """
    A group of processes that processes incoming requests for multiple module backends on a shared device.
    Runtime is usually created and managed by Server, humans need not apply.

    For debugging, you can start runtime manually with .start() or .run()

    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
    >>> runtime = Runtime(module_backends)
    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
    >>> runtime.ready.wait()  # await for runtime to load all blocks on device and create request pools
    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param module_backends: a dict [block uid -> ModuleBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
    :param device: if specified, moves all blocks and data to this device via .to(device=device).
      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)

    :param stats_report_interval: interval to collect and log statistics about runtime performance
    """

    # Sentinel sent through the shutdown pipe to wake up and stop iterate_minibatches_from_pools
    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def __init__(
        self,
        module_backends: Dict[str, ModuleBackend],
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        device: Optional[torch.device] = None,
        stats_report_interval: Optional[int] = None,
    ):
        super().__init__()
        self.module_backends = module_backends
        # flatten the pools of every backend into a single tuple that the runtime polls
        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.shutdown_trigger = mp.Event()
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches

        self.stats_report_interval = stats_report_interval
        if self.stats_report_interval is not None:
            self.stats_reporter = StatsReporter(self.stats_report_interval)

    def run(self):
        """Start all pools, then process incoming batches until shutdown is triggered."""
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for backend in self.module_backends.values():
                backend.module.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                if self.stats_report_interval is not None:
                    self.stats_reporter.start()
                logger.info("Started")

                batch_iterator = self.iterate_minibatches_from_pools()
                if self.prefetch_batches > 0:
                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)

                for pool, batch_index, batch in batch_iterator:
                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

                    start = time()
                    try:
                        outputs = pool.process_func(*batch)
                        # outputs are sent back asynchronously so that the main loop can take the next batch
                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])

                        batch_processing_time = time() - start

                        batch_size = outputs[0].size(0)
                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                        if self.stats_report_interval is not None:
                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

                    except KeyboardInterrupt:
                        raise
                    except BaseException as exception:
                        # a failed batch must not stop the runtime: report the error to the pool and keep going
                        logger.exception(f"Caught {exception}, attempting to recover")
                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])

            finally:
                if not self.shutdown_trigger.is_set():
                    self.shutdown()

    def shutdown(self):
        """Gracefully terminate a running runtime."""
        logger.info("Shutting down")
        self.ready.clear()

        if self.stats_report_interval is not None:
            self.stats_reporter.stop.set()
            # Thread.join() raises RuntimeError for a thread that was never started, which happens
            # if shutdown() is called before run() has reached stats_reporter.start()
            if self.stats_reporter.is_alive():
                self.stats_reporter.join()

        logger.debug("Terminating pools")
        for pool in self.pools:
            if pool.is_alive():
                pool.shutdown()
        logger.debug("Pools terminated")

        # trigger background thread to shutdown
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
        self.shutdown_trigger.set()

    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Chooses pool according to priority, then copies exposed batch and frees the buffer

        :param timeout: forwarded as the first argument of pool.load_batch_to_runtime
        :returns: a generator of (pool, batch_index, batch_tensors) that ends once SHUTDOWN_TRIGGER is received
        """
        with DefaultSelector() as selector:
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)

            while True:
                # wait until at least one batch_receiver becomes available
                logger.debug("Waiting for inputs from task pools")
                ready_fds = selector.select()
                ready_objects = {key.data for (key, events) in ready_fds}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shutdown, break from the loop

                logger.debug("Choosing the pool with first priority")

                pool = min(ready_objects, key=lambda pool: pool.priority)

                logger.debug(f"Loading batch from {pool.name}")
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                logger.debug(f"Loaded batch from {pool.name}")
                yield pool, batch_index, batch_tensors
|
||||
|
||||
|
||||
# One record per processed batch: number of examples in it and the time spent processing it (seconds)
BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))


class StatsReporter(threading.Thread):
    """
    Background thread that periodically drains reported batch statistics and logs a per-pool summary.

    :param report_interval: number of seconds between two consecutive reports
    """

    def __init__(self, report_interval: int):
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()  # set this event to make run() exit
        self.stats_queue = SimpleQueue()

    @staticmethod
    def _format_rate(count: float, total_time: float, unit_plural: str, unit_singular: str) -> str:
        """
        Format throughput as "<rate> <units>/s" when faster than one unit per second,
        otherwise as "<seconds> s/<unit>" (the reciprocal), so the number always matches its label.
        """
        if count <= 0:
            return f"0.00 {unit_plural}/s"
        if total_time <= 0:
            # timer resolution can yield a zero duration; avoid ZeroDivisionError in the reporter thread
            return f"inf {unit_plural}/s"
        rate = count / total_time
        if rate > 1:
            return f"{rate:.2f} {unit_plural}/s"
        return f"{total_time / count:.2f} s/{unit_singular}"

    def run(self):
        """Every report_interval seconds, log statistics for batches reported since the previous report."""
        # Event.wait returns True once stop is set, which ends the loop
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)

                batch_performance = self._format_rate(total_batches, total_time, "batches", "batch")
                example_performance = self._format_rate(total_examples, total_time, "examples", "example")

                logger.info(
                    f"{pool_uid}: "
                    f"{total_batches} batches ({batch_performance}), "
                    f"{total_examples} examples ({example_performance}), "
                    f"avg batch size {avg_batch_size:.2f}"
                )

    def report_stats(self, pool_uid, batch_size, processing_time):
        """Enqueue statistics of one processed batch; they are consumed by the next report in run()."""
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))
|
Loading…
Reference in New Issue