Clean up disk space (#152)

pull/153/head
Alexander Borzunov 1 year ago committed by GitHub
parent b04982c1a2
commit 701ec7e53e

@@ -8,6 +8,8 @@ If necessary, one can rewrite this to implement a different behavior, such as:
"""
from __future__ import annotations
import itertools
import time
from typing import Optional, OrderedDict, Union
import torch
@@ -17,7 +19,8 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.utils import get_file_from_repo
from petals.bloom.block import WrappedBloomBlock
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.server.block_utils import get_block_size
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -33,6 +36,7 @@ def load_pretrained_block(
torch_dtype: Union[torch.dtype, str] = "auto",
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
) -> WrappedBloomBlock:
"""Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
@@ -43,7 +47,12 @@ def load_pretrained_block(
block = WrappedBloomBlock(config)
state_dict = _load_state_dict(
converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
converted_model_name_or_path,
block_index,
config,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
if torch_dtype == "auto":
@@ -62,20 +71,56 @@ def load_pretrained_block(
def _load_state_dict(
pretrained_model_name_or_path: str,
block_index: Optional[int] = None,
block_index: int,
config: BloomConfig,
*,
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
cache_dir: str,
max_disk_space: Optional[int] = None,
min_backoff: float = 5,
) -> OrderedDict[str, torch.Tensor]:
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
archive_file = get_file_from_repo(
pretrained_model_name_or_path,
filename=WEIGHTS_NAME,
revision=revision,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
)
state_dict = torch.load(archive_file, map_location="cpu")
return state_dict
revision = BLOCK_BRANCH_PREFIX + str(block_index)
# First, try to find the weights locally
try:
with allow_cache_reads(cache_dir):
archive_file = get_file_from_repo(
pretrained_model_name_or_path,
filename=WEIGHTS_NAME,
revision=revision,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
local_files_only=True,
)
if archive_file is not None:
return torch.load(archive_file, map_location="cpu")
except Exception:
logger.debug(
f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
)
    # If not found, ensure that we have enough disk space to download the weights (removing old blocks if necessary)
for attempt_no in itertools.count():
try:
with allow_cache_writes(cache_dir):
block_size = get_block_size(config, "disk")
free_disk_space_for(
pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
)
archive_file = get_file_from_repo(
pretrained_model_name_or_path,
filename=WEIGHTS_NAME,
revision=revision,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
local_files_only=False,
)
return torch.load(archive_file, map_location="cpu")
except Exception as e:
delay = min_backoff * (2**attempt_no)
logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
time.sleep(delay)
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
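
As an aside, the download loop above is plain exponential backoff with unbounded retries. A minimal standalone sketch of the same pattern (the `fetch` callable and the printed message are illustrative, not part of this diff):

```python
import itertools
import time

def retry_forever(fetch, min_backoff: float = 5.0):
    """Call fetch() until it succeeds, doubling the delay after each failure."""
    for attempt_no in itertools.count():
        try:
            return fetch()
        except Exception:
            delay = min_backoff * (2**attempt_no)  # 5 s, 10 s, 20 s, ...
            print(f"fetch() failed, retrying in {delay:.0f} sec")
            time.sleep(delay)
```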

@@ -47,8 +47,18 @@ def main():
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=2048,
                        help='Maximum total sequence length permitted per inference, defaults to 2048 tokens')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument("--max_disk_space", type=str, default=None,
help="Maximal disk space used for caches. Example: 50GB, 100GiB (GB != GiB here). "
"Default: unlimited. "
"For bigscience/bloom-petals, this default means that the server may use up to "
"min(free_disk_space, 350GB) in the worst case, which happens when the server runs "
"for a long time and caches all model blocks after a number of rebalancings. "
"However, this worst case is unlikely, expect the server to consume "
"the disk space equal to 2-4x of your GPU memory on average.")
parser.add_argument('--device', type=str, default=None, required=False,
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, default="auto",
@@ -129,7 +139,14 @@ def main():
attn_cache_size = parse_size(attn_cache_size)
assert isinstance(
attn_cache_size, (int, type(None))
), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
), "Unrecognized value for --attn_cache_size. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"
max_disk_space = args.pop("max_disk_space")
if max_disk_space is not None:
max_disk_space = parse_size(max_disk_space)
assert isinstance(
max_disk_space, (int, type(None))
), "Unrecognized value for --max_disk_space. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"
if args.pop("new_swarm"):
args["initial_peers"] = []
@@ -138,7 +155,7 @@ def main():
if load_in_8bit is not None:
args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
server = Server(**args, compression=compression, max_disk_space=max_disk_space, attn_cache_size=attn_cache_size)
try:
server.run()
except KeyboardInterrupt:
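
As a sanity check on the 350GB worst case quoted in the `--max_disk_space` help: BLOOM-176B spreads roughly 176B parameters over 70 transformer blocks, stored in bfloat16. A back-of-the-envelope estimate (approximate figures, not the exact output of `get_block_size`):

```python
total_params = 176e9  # BLOOM-176B
n_blocks = 70         # transformer blocks in bigscience/bloom-petals
bytes_per_param = 2   # bfloat16

block_size = total_params / n_blocks * bytes_per_param
print(f"{block_size / 1e9:.1f} GB per block")            # ~5.0 GB
print(f"{block_size * n_blocks / 1e9:.0f} GB in total")  # ~352 GB, matching the help text
```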

@@ -54,6 +54,14 @@ class TransformerConnectionHandler(ConnectionHandler):
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self._prioritizer = task_prioritizer
def shutdown(self):
if self.is_alive():
self._outer_pipe.send("_shutdown")
self.join(self.shutdown_timeout)
if self.is_alive():
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
self.terminate()
async def _gather_inputs(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> Tuple[str, List[torch.Tensor], Dict]:
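
The `shutdown` method added above follows a common two-stage pattern for hivemind-style worker processes: ask politely over a pipe, then escalate to SIGTERM. A self-contained sketch of the same idea (the `Worker` class is illustrative, not the actual handler):

```python
import multiprocessing as mp

class Worker(mp.Process):
    shutdown_timeout = 5.0  # grace period before escalating

    def __init__(self):
        super().__init__()
        self._outer_pipe, self._inner_pipe = mp.Pipe()

    def run(self):
        while True:
            if self._inner_pipe.recv() == "_shutdown":
                return  # leave run() so the process exits cleanly

    def shutdown(self):
        if self.is_alive():
            self._outer_pipe.send("_shutdown")  # stage 1: cooperative stop
            self.join(self.shutdown_timeout)
        if self.is_alive():
            self.terminate()  # stage 2: SIGTERM, as in the handler above

if __name__ == "__main__":
    worker = Worker()
    worker.start()
    worker.shutdown()
```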

@@ -29,6 +29,7 @@ from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.throughput import get_host_throughput
from petals.utils.convert_8bit import replace_8bit_linear
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -56,6 +57,7 @@ class Server:
torch_dtype: str = "auto",
revision: str = "main",
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
attn_cache_size: Optional[int] = None,
alloc_timeout: float = 60,
device: Optional[Union[str, torch.device]] = None,
@@ -82,7 +84,6 @@ class Server:
self.num_handlers = num_handlers
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
self.cache_dir = cache_dir
self.compression = compression
self.stats_report_interval, self.update_period = stats_report_interval, update_period
self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
@@ -117,7 +118,8 @@ class Server:
self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
logger.info("Connecting to the public Petals swarm")
logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
logger.info("Please check that your server is reachable at http://health.petals.ml")
else:
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
@@ -158,6 +160,11 @@ class Server:
logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
self.cache_dir = cache_dir
self.max_disk_space = max_disk_space
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput = get_host_throughput(
@@ -213,6 +220,7 @@ class Server:
inference_max_length=self.inference_max_length,
torch_dtype=self.torch_dtype,
cache_dir=self.cache_dir,
max_disk_space=self.max_disk_space,
device=self.device,
compression=self.compression,
stats_report_interval=self.stats_report_interval,
@@ -308,7 +316,8 @@ class ModuleContainer(threading.Thread):
min_batch_size: int,
max_batch_size: int,
torch_dtype: torch.dtype,
cache_dir: Optional[str],
cache_dir: str,
max_disk_space: int,
device: Union[str, torch.device],
compression: CompressionType,
update_period: float,
@@ -340,6 +349,7 @@ class ModuleContainer(threading.Thread):
torch_dtype=torch_dtype,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
if load_in_8bit:
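
To summarize how the new argument reaches the download code, the diff threads it through several layers; condensed (other parameters omitted):

```python
# run_server CLI: --max_disk_space "100GB"  ->  parse_size  ->  int (bytes)
# Server(..., max_disk_space=...)                       # stored as self.max_disk_space
#   -> ModuleContainer.create(..., max_disk_space=...)
#     -> load_pretrained_block(..., max_disk_space=...)
#       -> _load_state_dict(..., max_disk_space=...)
#         -> free_disk_space_for(model_name, block_size, max_disk_space=...)
```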

@@ -1,4 +1,86 @@
import fcntl
import os
import shutil
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
import huggingface_hub
from hivemind.utils.logging import get_logger
logger = get_logger(__file__)
DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))
BLOCKS_LOCK_FILE = "blocks.lock"
@contextmanager
def _blocks_lock(cache_dir: Optional[str], mode: int):
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
lock_path = Path(cache_dir, BLOCKS_LOCK_FILE)
os.makedirs(lock_path.parent, exist_ok=True)
with open(lock_path, "wb") as lock_fd:
fcntl.flock(lock_fd.fileno(), mode)
# The OS will release the lock when lock_fd is closed or the process is killed
yield
def allow_cache_reads(cache_dir: Optional[str]):
"""Allows simultaneous reads, guarantees that blocks won't be removed along the way (shared lock)"""
return _blocks_lock(cache_dir, fcntl.LOCK_SH)
def allow_cache_writes(
cache_dir: Optional[str], *, reserve: Optional[int] = None, max_disk_space: Optional[int] = None
):
"""Allows saving new blocks and removing the old ones (exclusive lock)"""
return _blocks_lock(cache_dir, fcntl.LOCK_EX)
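
These two helpers rely on POSIX advisory locks: any number of LOCK_SH holders may coexist, while LOCK_EX must wait for all of them to release. A small demo of those semantics (POSIX-only; the lock path is illustrative):

```python
import fcntl

lock_path = "/tmp/blocks_demo.lock"
with open(lock_path, "wb") as reader_a, open(lock_path, "wb") as reader_b:
    fcntl.flock(reader_a.fileno(), fcntl.LOCK_SH)  # two shared locks...
    fcntl.flock(reader_b.fileno(), fcntl.LOCK_SH)  # ...held at the same time
    with open(lock_path, "wb") as writer:
        try:
            # A non-blocking exclusive lock fails while shared locks are held
            fcntl.flock(writer.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        except BlockingIOError:
            print("writer must wait until all readers release the lock")
```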
def free_disk_space_for(
model_name: str,
size: int,
*,
cache_dir: Optional[str],
max_disk_space: Optional[int],
    os_quota: int = 1024**3,  # Minimum space we should leave for the OS to function normally
):
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
cache_info = huggingface_hub.scan_cache_dir(cache_dir)
model_repos = [repo for repo in cache_info.repos if repo.repo_type == "model" and repo.repo_id == model_name]
occupied_space = sum(repo.size_on_disk for repo in model_repos)
available_space = shutil.disk_usage(cache_dir).free - os_quota
if max_disk_space is not None:
available_space = min(available_space, max_disk_space - occupied_space)
if size <= available_space:
return
revisions = [revision for repo in model_repos for revision in repo.revisions]
revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified))
# Remove as few least recently used blocks as possible
pending_removal = []
freed_space = 0
extra_space_needed = size - available_space
for rev in revisions:
pending_removal.append(rev.commit_hash)
freed_space += rev.size_on_disk
if freed_space >= extra_space_needed:
break
if pending_removal:
gib = 1024**3
logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space")
delete_strategy = cache_info.delete_revisions(*pending_removal)
delete_strategy.execute()
    if freed_space < extra_space_needed:
        raise RuntimeError(
            f"Insufficient disk space to load a block. Please free {(extra_space_needed - freed_space) / (1024 ** 3):.1f} GiB "
            f"on the volume for {cache_dir} or increase --max_disk_space if you set it manually"
        )
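
Taken together, the intended call sequence mirrors what `_load_state_dict` does in this diff: take the exclusive lock, make room, then download while still holding the lock. A condensed usage sketch (the sizes are illustrative):

```python
from petals.utils.disk_cache import allow_cache_writes, free_disk_space_for

cache_dir = None  # None falls back to DEFAULT_CACHE_DIR inside the helpers
with allow_cache_writes(cache_dir):
    # Reserve room for one ~5 GB block, capping this model's cache at 100 GiB
    free_disk_space_for(
        "bigscience/bloom-petals", 5 * 10**9, cache_dir=cache_dir, max_disk_space=100 * 2**30
    )
    # ...download the block here, while the exclusive lock is still held...
```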
