Reset MemoryCache during rebalancings (#154)

Before this PR, if there were open inference sessions at the moment rebalancing was triggered, their attention cache was never properly destroyed.
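
With this change, the MemoryCache is created inside ModuleContainer.create() instead of once in Server.__init__() (see the server.py hunks below), so every rebalancing builds a fresh cache and the previous one is released together with the old backends. A rough sketch of that ownership pattern, using simplified stand-in classes rather than the real Petals code:

```python
# Hedged sketch, not the real Petals classes: simplified stand-ins that only
# illustrate the new ownership (the cache lives and dies with each ModuleContainer).
class MemoryCache:
    def __init__(self, device, max_size_bytes, alloc_timeout):
        self.device, self.max_size_bytes, self.alloc_timeout = device, max_size_bytes, alloc_timeout
        self._allocated_tensors = {}  # dropped together with the cache object


class ModuleContainer:
    def __init__(self, memory_cache, block_indices):
        self.memory_cache, self.block_indices = memory_cache, block_indices

    @classmethod
    def create(cls, device, attn_cache_size, alloc_timeout, block_indices):
        # a fresh cache is built alongside the blocks on every (re)start
        memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
        return cls(memory_cache, block_indices)


# Server side: remember only the settings; rebuild the container (and thus the cache)
# whenever rebalancing picks a different set of blocks.
attn_cache_size, alloc_timeout = 8 * 2**30, 60.0
container = ModuleContainer.create("cuda:0", attn_cache_size, alloc_timeout, [3, 4, 5])
container = ModuleContainer.create("cuda:0", attn_cache_size, alloc_timeout, [6, 7, 8])
# the first container's MemoryCache is no longer referenced and can be reclaimed
```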
Alexander Borzunov committed 1 year ago (via GitHub)
parent bd91be27ea
commit 73df69a117

@@ -33,12 +33,10 @@ class MemoryCache:
self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
- self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
- self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
+ self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
self.runtime_pid = os.getpid()
self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False) # any ConnectionHandler -> runtime
- self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
self._lock_acquire_memory = mp.Lock()
self._memory_freed_event = mp.Event()
@@ -83,14 +81,12 @@ class MemoryCache:
allocated_handle = int(self.handle_counter)
self.current_size_bytes += allocated_size_bytes
self.handle_counter += 1 # note: this will eventually overflow and it is okay
- self._pending_messages.value += 1
self._pipe_send.send((allocated_handle, descr))
yield allocated_handle
finally:
if allocated_handle is not None:
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
- self._pending_messages.value += 1
self._pipe_send.send((allocated_handle, None)) # signal runtime to free that handle
self.current_size_bytes -= allocated_size_bytes
self._memory_freed_event.set()
@@ -122,13 +118,9 @@ class MemoryCache:
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
with self._lock_metadata:
- if self._allocated_tensors is None:
- self._allocated_tensors = {}
# read creation/deletion requests from connection handlers
- for i in range(int(self._pending_messages.value)):
+ while self._pipe_recv.poll():
recv_handle, recv_data = self._pipe_recv.recv()
- self._pending_messages.value -= 1
if isinstance(recv_data, TensorDescriptor):
self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
elif recv_data is None:
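
With the cache recreated on every rebalancing, MemoryCache also stops tracking a shared _pending_messages counter and instead drains the allocation/free pipe with poll(), as in the hunk above. A minimal, self-contained sketch of that drain pattern, using hypothetical string payloads instead of real TensorDescriptors:

```python
import multiprocessing as mp

# Connection handlers -> runtime: (handle, descriptor) means "allocate",
# (handle, None) means "free". Hypothetical payloads, not the real TensorDescriptor.
pipe_recv, pipe_send = mp.Pipe(duplex=False)
pipe_send.send((1, "descriptor-for-handle-1"))  # pretend allocation request
pipe_send.send((1, None))                       # pretend free request

allocated = {}
# Runtime side: no shared pending-message counter, just drain whatever is buffered.
while pipe_recv.poll():
    handle, payload = pipe_recv.recv()
    if payload is not None:
        allocated[handle] = payload  # in Petals: descr.make_zeros(device=...)
    else:
        allocated.pop(handle, None)  # free the tensor for that handle

print(allocated)  # prints {} because the handle was allocated and then freed
```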

@@ -1,6 +1,7 @@
from __future__ import annotations
import gc
+ import itertools
import math
import multiprocessing as mp
import random
@@ -72,8 +73,8 @@ class Server:
prefetch_batches: int = 1,
sender_threads: int = 1,
balance_quality: float = 0.75,
- mean_balance_check_period: float = 60,
- mean_block_selection_delay: float = 0.5,
+ mean_balance_check_period: float = 120,
+ mean_block_selection_delay: float = 2.5,
use_auth_token: Optional[str] = None,
load_in_8bit: Optional[bool] = None,
**kwargs,
@@ -119,7 +120,6 @@ class Server:
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
- logger.info("Please check that your server is reachable at http://health.petals.ml")
else:
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
@@ -157,8 +157,8 @@ class Server:
if attn_cache_size is None:
# Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
+ self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
- self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
@@ -211,7 +211,8 @@ class Server:
prefix=self.prefix,
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
- memory_cache=self.memory_cache,
+ attn_cache_size=self.attn_cache_size,
+ alloc_timeout=self.alloc_timeout,
throughput=self.throughput,
block_indices=block_indices,
num_handlers=self.num_handlers,
@@ -310,7 +311,8 @@ class ModuleContainer(threading.Thread):
prefix: str,
converted_model_name_or_path: str,
block_config: BloomConfig,
- memory_cache: MemoryCache,
+ attn_cache_size: int,
+ alloc_timeout: float,
throughput: float,
block_indices: List[int],
min_batch_size: int,
@@ -339,8 +341,9 @@ class ModuleContainer(threading.Thread):
joining_announcer.start()
logger.info(f"Announced that blocks {block_indices} are joining")
+ memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+ blocks = {}
try:
- blocks = {}
for module_uid, block_index in zip(module_uids, block_indices):
block = load_pretrained_block(
converted_model_name_or_path,
@@ -380,6 +383,10 @@ class ModuleContainer(threading.Thread):
max_batch_size=max_batch_size,
)
except:
+ logger.debug("Shutting down backends")
+ for backend in blocks.values():
+ backend.shutdown()
joining_announcer.stop.set()
joining_announcer.join()
declare_active_modules(
@@ -563,7 +570,7 @@ class ModuleAnnouncerThread(threading.Thread):
self.stop = threading.Event()
def run(self) -> None:
- while True:
+ for iter_no in itertools.count():
declare_active_modules(
self.dht,
self.module_uids,
@@ -571,5 +578,10 @@ class ModuleAnnouncerThread(threading.Thread):
state=self.state,
throughput=self.throughput,
)
+ if iter_no == 0 and self.state == ServerState.JOINING:
+ logger.info(
+ f"Please ensure that your server is reachable. "
+ f"For public swarm, open http://health.petals.ml and find peer_id = {self.dht.peer_id}"
+ )
if self.stop.wait(self.update_period):
break
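
The reachability hint removed from Server.__init__ is now logged by ModuleAnnouncerThread on its first announcement while the server is still JOINING, which is why the loop switched from while True to itertools.count(). A simplified sketch of that loop shape, with a generic announce() placeholder instead of the real declare_active_modules call:

```python
import itertools
import threading

# Sketch of the announcement loop shape; announce(), joining, and update_period
# are generic placeholders, not the actual ModuleAnnouncerThread attributes.
def announcer_loop(stop: threading.Event, update_period: float, joining: bool, announce) -> None:
    for iter_no in itertools.count():
        announce()
        # only the very first announcement of a JOINING server prints the reachability hint
        if iter_no == 0 and joining:
            print("Please ensure that your server is reachable; "
                  "for the public swarm, check http://health.petals.ml")
        if stop.wait(update_period):  # wakes up early once stop is set
            break
```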
