You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
petals/tests/test_peft.py

67 lines
1.5 KiB
Python

import os
import shutil
import pytest
from huggingface_hub import snapshot_download
from petals.utils.peft import check_peft_repository, load_peft
# Hugging Face Hub repositories used as fixtures by the tests in this module.
# test_check_peft expects the first to fail check_peft_repository and the second to pass it.
UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
# NOTE(review): not referenced by the tests visible here (they use pytest's ``tmpdir``); confirm before removing.
TMP_CACHE_DIR = "tmp_cache/"
def clear_dir(path_to_dir):
    """Ensure ``path_to_dir`` exists and is an empty directory.

    An existing directory at that path is deleted together with all of its
    contents and then recreated. If nothing exists at the path yet, the
    directory (and any missing parent directories) is simply created rather
    than raising ``FileNotFoundError``.

    Args:
        path_to_dir: path-like location of the directory to reset, e.g. the
            pytest ``tmpdir`` fixture.
    """
    try:
        shutil.rmtree(path_to_dir)
    except FileNotFoundError:
        # Nothing to remove; just create the directory below.
        pass
    # makedirs (unlike mkdir) also creates missing parents.
    os.makedirs(path_to_dir, exist_ok=True)
def dir_empty(path_to_dir):
    """Return ``True`` if the directory at ``path_to_dir`` contains no entries, else ``False``."""
    return not os.listdir(path_to_dir)
@pytest.mark.forked
def test_check_peft():
    """Assert ``check_peft_repository`` is falsy for UNSAFE_PEFT_REPO and truthy for SAFE_PEFT_REPO."""
    # The failure message names the constant actually under test (previously "NOSAFE_PEFT_REPO").
    assert not check_peft_repository(UNSAFE_PEFT_REPO), "UNSAFE_PEFT_REPO is safe to load."
    assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
@pytest.mark.forked
def test_load_noncached(tmpdir):
    """Load PEFT weights starting from an empty cache directory.

    ``load_peft`` on the unsafe repository must raise and leave ``tmpdir``
    empty. ``load_peft`` on the safe repository must then write something
    into ``tmpdir``.
    """
    clear_dir(tmpdir)
    # NOTE(review): any Exception satisfies this; the exact type load_peft raises is not visible here — consider narrowing.
    with pytest.raises(Exception):
        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
    # Must run after the raises-block: nothing from the rejected repo may have been written to the cache.
    assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
    assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
@pytest.mark.forked
def test_load_cached(tmpdir):
    """Load the safe PEFT repo after pre-downloading it into the same ``cache_dir``."""
    clear_dir(tmpdir)
    # Pre-populate tmpdir with a snapshot of the repo, presumably so load_peft hits the cache.
    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
    # NOTE(review): only requires that load_peft does not raise; reuse of the cached copy is not verified.
    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
@pytest.mark.forked
def test_load_layer_exists(tmpdir):
    """Call ``load_peft`` on the safe repo for block index 2, using a fresh cache directory."""
    clear_dir(tmpdir)
    # NOTE(review): passes as long as load_peft does not raise; the returned value is not inspected.
    load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
@pytest.mark.forked
def test_load_layer_nonexists(tmpdir):
    """Call ``load_peft`` on the safe repo for block index 1337, using a fresh cache directory."""
    clear_dir(tmpdir)
    # NOTE(review): 1337 is presumably beyond the model's number of blocks (per the test name).
    # The test only requires that load_peft does not raise; its return value is not checked — confirm intent.
    load_peft(
        SAFE_PEFT_REPO,
        block_idx=1337,
        cache_dir=tmpdir,
    )