Use common folder for all caches, make it a volume in Dockerfile (#141)

Alexander Borzunov authored 1 year ago, committed by GitHub
parent 5f50ea9c79
commit e99bf36647

@@ -19,9 +19,12 @@ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -
 ENV PATH="/opt/conda/bin:${PATH}"
 RUN conda install python~=3.10 pip && \
-    pip install --no-cache-dir "torch>=1.12" torchvision torchaudio && \
+    pip install --no-cache-dir "torch>=1.12" && \
     conda clean --all && rm -rf ~/.cache/pip
+
+VOLUME /cache
+ENV PETALS_CACHE=/cache
 COPY . petals/
 RUN pip install -e petals[dev]
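
Note (annotation, not part of the diff): the `ENV PETALS_CACHE=/cache` line is what connects the Docker volume to the Python side. The new `petals/utils/disk_cache.py` module added at the end of this commit reads the variable at import time, so a quick sanity check could look like this (hypothetical snippet):

```python
# Hypothetical check: PETALS_CACHE must be set *before* petals.utils.disk_cache is imported,
# because DEFAULT_CACHE_DIR is computed once at module import time.
import os

os.environ.setdefault("PETALS_CACHE", "/cache")  # the Dockerfile sets this via ENV PETALS_CACHE=/cache
from petals.utils.disk_cache import DEFAULT_CACHE_DIR

print(DEFAULT_CACHE_DIR)  # "/cache" inside the image; ~/.cache/petals when the variable is unset
```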

@@ -40,7 +40,7 @@ Connect your own GPU and increase Petals capacity:
 (conda) $ python -m petals.cli.run_server bigscience/bloom-petals
 # Or using a GPU-enabled Docker image
-sudo docker run --net host --ipc host --gpus all --rm learningathome/petals:main \
+sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
   python -m petals.cli.run_server bigscience/bloom-petals
 ```

@@ -16,6 +16,7 @@ from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url
 from petals.bloom import BloomBlock, BloomConfig
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -37,8 +38,12 @@ def load_pretrained_block(
     cache_dir: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
     block = BloomBlock(config, layer_number=block_index)
     state_dict = _load_state_dict(
         converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
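
Annotation (not part of the diff): a hypothetical call showing the new default in action; the model name comes from the README above, and the keyword arguments follow the signature shown in this hunk:

```python
# Hypothetical: loading a single block without passing cache_dir now resolves to
# DEFAULT_CACHE_DIR (i.e. $PETALS_CACHE or ~/.cache/petals) inside load_pretrained_block.
from petals.bloom.from_pretrained import load_pretrained_block

block = load_pretrained_block(
    "bigscience/bloom-petals",  # converted model repo, as in the README
    block_index=0,
    cache_dir=None,  # falls back to DEFAULT_CACHE_DIR
)
```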

@@ -7,7 +7,7 @@ from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 from petals.bloom.from_pretrained import BloomBlock
-from petals.server.cache import MemoryCache
+from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy

@@ -24,8 +24,8 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend
-from petals.server.cache import MemoryCache
 from petals.server.handler import TransformerConnectionHandler
+from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_8bit import replace_8bit_linear
@@ -160,7 +160,12 @@ class Server:
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_host_throughput(
-                self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
+                self.block_config,
+                device,
+                torch_dtype,
+                load_in_8bit=load_in_8bit,
+                force_eval=(throughput == "eval"),
+                cache_dir=cache_dir,
             )
         self.throughput = throughput

@@ -2,11 +2,10 @@ import fcntl
 import json
 import os
 import subprocess
-import tempfile
 import time
 from hashlib import sha256
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -15,15 +14,12 @@ from petals.bloom.block import BloomBlock
 from petals.bloom.model import BloomConfig
 from petals.bloom.ops import build_alibi_tensor
 from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
-DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput_v2.json")
-DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
 def get_host_throughput(
     config: BloomConfig,
     device: torch.device,
@@ -31,8 +27,7 @@ def get_host_throughput(
     *,
     load_in_8bit: bool,
     force_eval: bool = False,
-    cache_path: str = DEFAULT_CACHE_PATH,
-    lock_path: str = DEFAULT_LOCK_PATH,
+    cache_dir: Optional[str] = None,
 ) -> float:
     # Resolve default dtypes
     if dtype == "auto" or dtype is None:
@@ -40,6 +35,11 @@
     if dtype == "auto" or dtype is None:
         dtype = torch.float32
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+    lock_path = Path(cache_dir, "throughput.lock")
+    cache_path = Path(cache_dir, "throughput_v2.json")
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, "wb") as lock_fd:
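
Annotation (not part of the diff): the system-wide lock mentioned in the comment is an exclusive advisory file lock on `throughput.lock`. A simplified sketch of that pattern, not the verbatim petals code:

```python
# Simplified sketch of the file-lock pattern: whichever process acquires the exclusive
# lock first measures (or reads) the cached throughput; other processes block until it is released.
import fcntl
import os
from pathlib import Path

def measure_with_lock(cache_dir: str) -> None:
    lock_path = Path(cache_dir, "throughput.lock")
    os.makedirs(lock_path.parent, exist_ok=True)
    with open(lock_path, "wb") as lock_fd:
        fcntl.flock(lock_fd, fcntl.LOCK_EX)  # exclusive lock; blocks if another process holds it
        try:
            pass  # read throughput_v2.json or run the benchmark here
        finally:
            fcntl.flock(lock_fd, fcntl.LOCK_UN)  # also released implicitly when the file is closed
```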

@@ -0,0 +1,4 @@
+import os
+from pathlib import Path
+
+DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))
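
Annotation (not part of the diff): with this module in place, the caches touched by this commit share one directory. A hypothetical inspection, with file names taken from the throughput changes above:

```python
# Hypothetical: list what the shared cache directory contains after running a server.
from pathlib import Path
from petals.utils.disk_cache import DEFAULT_CACHE_DIR

for entry in sorted(Path(DEFAULT_CACHE_DIR).glob("*")):
    print(entry.name)  # e.g. throughput.lock, throughput_v2.json, cached block weights
```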