diff --git a/Dockerfile b/Dockerfile
index 5de479d..d32e93a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,9 +19,12 @@ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -
 ENV PATH="/opt/conda/bin:${PATH}"
 
 RUN conda install python~=3.10 pip && \
-    pip install --no-cache-dir "torch>=1.12" torchvision torchaudio && \
+    pip install --no-cache-dir "torch>=1.12" && \
     conda clean --all && rm -rf ~/.cache/pip
 
+VOLUME /cache
+ENV PETALS_CACHE=/cache
+
 COPY . petals/
 RUN pip install -e petals[dev]
diff --git a/README.md b/README.md
index b1dab00..c834a4f 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,7 @@ Connect your own GPU and increase Petals capacity:
 (conda) $ python -m petals.cli.run_server bigscience/bloom-petals
 
 # Or using a GPU-enabled Docker image
-sudo docker run --net host --ipc host --gpus all --rm learningathome/petals:main \
+sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
     python -m petals.cli.run_server bigscience/bloom-petals
 ```
diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py
index 57f1129..16a4e72 100644
--- a/src/petals/bloom/from_pretrained.py
+++ b/src/petals/bloom/from_pretrained.py
@@ -16,6 +16,7 @@ from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url
 
 from petals.bloom import BloomBlock, BloomConfig
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -37,8 +38,12 @@ def load_pretrained_block(
     cache_dir: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
+
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+
     block = BloomBlock(config, layer_number=block_index)
     state_dict = _load_state_dict(
         converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index c29851c..4fadda2 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -7,7 +7,7 @@ from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 
 from petals.bloom.from_pretrained import BloomBlock
-from petals.server.cache import MemoryCache
+from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
diff --git a/src/petals/server/cache.py b/src/petals/server/memory_cache.py
similarity index 100%
rename from src/petals/server/cache.py
rename to src/petals/server/memory_cache.py
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 0d148ad..7c12f48 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -24,8 +24,8 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend
-from petals.server.cache import MemoryCache
 from petals.server.handler import TransformerConnectionHandler
+from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_8bit import replace_8bit_linear
@@ -160,7 +160,12 @@ class Server:
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_host_throughput(
-                self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
+                self.block_config,
+                device,
+                torch_dtype,
+                load_in_8bit=load_in_8bit,
+                force_eval=(throughput == "eval"),
+                cache_dir=cache_dir,
             )
         self.throughput = throughput
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index c4a9899..fc08eba 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -2,11 +2,10 @@ import fcntl
 import json
 import os
 import subprocess
-import tempfile
 import time
 from hashlib import sha256
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -15,15 +14,12 @@ from petals.bloom.block import BloomBlock
 from petals.bloom.model import BloomConfig
 from petals.bloom.ops import build_alibi_tensor
 from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput_v2.json")
-DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
-
-
 def get_host_throughput(
     config: BloomConfig,
     device: torch.device,
@@ -31,8 +27,7 @@
     *,
     load_in_8bit: bool,
     force_eval: bool = False,
-    cache_path: str = DEFAULT_CACHE_PATH,
-    lock_path: str = DEFAULT_LOCK_PATH,
+    cache_dir: Optional[str] = None,
 ) -> float:
     # Resolve default dtypes
     if dtype == "auto" or dtype is None:
@@ -40,6 +35,11 @@
     if dtype == "auto" or dtype is None:
         dtype = torch.float32
 
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+    lock_path = Path(cache_dir, "throughput.lock")
+    cache_path = Path(cache_dir, "throughput_v2.json")
+
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, "wb") as lock_fd:
diff --git a/src/petals/utils/disk_cache.py b/src/petals/utils/disk_cache.py
new file mode 100644
index 0000000..9586c6d
--- /dev/null
+++ b/src/petals/utils/disk_cache.py
@@ -0,0 +1,4 @@
+import os
+from pathlib import Path
+
+DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))
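
Note on cache resolution (illustration, not part of the patch): after this change, the disk artifacts (the throughput probe results and their lock file) live under one directory that `PETALS_CACHE` can override, which is what lets the Docker image declare it as a volume. A minimal sketch of the resolution logic, mirroring `src/petals/utils/disk_cache.py` and `get_host_throughput()` above; the `/tmp/petals-demo` path is just an example:

    import os
    from pathlib import Path

    # Same logic as disk_cache.py: the env var wins, otherwise ~/.cache/petals.
    # In the real module, DEFAULT_CACHE_DIR is evaluated once at import time,
    # so PETALS_CACHE must be set before petals is imported (the Dockerfile
    # does this via ENV; here we set it in-process for demonstration).
    os.environ.setdefault("PETALS_CACHE", "/tmp/petals-demo")
    cache_dir = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))

    # get_host_throughput() now derives both of its files from that directory:
    lock_path = Path(cache_dir, "throughput.lock")
    cache_path = Path(cache_dir, "throughput_v2.json")

    os.makedirs(lock_path.parent, exist_ok=True)
    print(f"lock: {lock_path}\ncache: {cache_path}")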